|
|
|
@ -12,7 +12,7 @@ import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
|
from .registry import register_model
|
|
|
|
|
from .helpers import load_pretrained, adapt_model_from_file
|
|
|
|
|
from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn
|
|
|
|
|
from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn, BlurPool2d
|
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
|
|
|
|
|
|
|
|
__all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this
|
|
|
|
@ -118,6 +118,11 @@ default_cfgs = {
|
|
|
|
|
'ecaresnet101d_pruned': _cfg(
|
|
|
|
|
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth',
|
|
|
|
|
interpolation='bicubic'),
|
|
|
|
|
'resnetblur18': _cfg(
|
|
|
|
|
interpolation='bicubic'),
|
|
|
|
|
'resnetblur50': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnetblur50-84f4748f.pth',
|
|
|
|
|
interpolation='bicubic')
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -131,7 +136,7 @@ class BasicBlock(nn.Module):
|
|
|
|
|
|
|
|
|
|
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
|
|
|
|
|
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
|
|
|
|
|
attn_layer=None, drop_block=None, drop_path=None):
|
|
|
|
|
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
|
|
|
|
|
super(BasicBlock, self).__init__()
|
|
|
|
|
|
|
|
|
|
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
|
|
|
|
@ -139,12 +144,15 @@ class BasicBlock(nn.Module):
|
|
|
|
|
first_planes = planes // reduce_first
|
|
|
|
|
outplanes = planes * self.expansion
|
|
|
|
|
first_dilation = first_dilation or dilation
|
|
|
|
|
use_aa = aa_layer is not None
|
|
|
|
|
|
|
|
|
|
self.conv1 = nn.Conv2d(
|
|
|
|
|
inplanes, first_planes, kernel_size=3, stride=stride, padding=first_dilation,
|
|
|
|
|
inplanes, first_planes, kernel_size=3, stride=1 if use_aa else stride, padding=first_dilation,
|
|
|
|
|
dilation=first_dilation, bias=False)
|
|
|
|
|
self.bn1 = norm_layer(first_planes)
|
|
|
|
|
self.act1 = act_layer(inplace=True)
|
|
|
|
|
self.aa = aa_layer(channels=first_planes) if stride == 2 and use_aa else None
|
|
|
|
|
|
|
|
|
|
self.conv2 = nn.Conv2d(
|
|
|
|
|
first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False)
|
|
|
|
|
self.bn2 = norm_layer(outplanes)
|
|
|
|
@ -169,6 +177,8 @@ class BasicBlock(nn.Module):
|
|
|
|
|
if self.drop_block is not None:
|
|
|
|
|
x = self.drop_block(x)
|
|
|
|
|
x = self.act1(x)
|
|
|
|
|
if self.aa is not None:
|
|
|
|
|
x = self.aa(x)
|
|
|
|
|
|
|
|
|
|
x = self.conv2(x)
|
|
|
|
|
x = self.bn2(x)
|
|
|
|
@ -195,22 +205,26 @@ class Bottleneck(nn.Module):
|
|
|
|
|
|
|
|
|
|
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
|
|
|
|
|
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
|
|
|
|
|
attn_layer=None, drop_block=None, drop_path=None):
|
|
|
|
|
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
|
|
|
|
|
super(Bottleneck, self).__init__()
|
|
|
|
|
|
|
|
|
|
width = int(math.floor(planes * (base_width / 64)) * cardinality)
|
|
|
|
|
first_planes = width // reduce_first
|
|
|
|
|
outplanes = planes * self.expansion
|
|
|
|
|
first_dilation = first_dilation or dilation
|
|
|
|
|
use_aa = aa_layer is not None
|
|
|
|
|
|
|
|
|
|
self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
|
|
|
|
|
self.bn1 = norm_layer(first_planes)
|
|
|
|
|
self.act1 = act_layer(inplace=True)
|
|
|
|
|
|
|
|
|
|
self.conv2 = nn.Conv2d(
|
|
|
|
|
first_planes, width, kernel_size=3, stride=stride,
|
|
|
|
|
first_planes, width, kernel_size=3, stride=1 if use_aa else stride,
|
|
|
|
|
padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
|
|
|
|
|
self.bn2 = norm_layer(width)
|
|
|
|
|
self.act2 = act_layer(inplace=True)
|
|
|
|
|
self.aa = aa_layer(channels=width) if stride == 2 and use_aa else None
|
|
|
|
|
|
|
|
|
|
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
|
|
|
|
|
self.bn3 = norm_layer(outplanes)
|
|
|
|
|
|
|
|
|
@ -240,6 +254,8 @@ class Bottleneck(nn.Module):
|
|
|
|
|
if self.drop_block is not None:
|
|
|
|
|
x = self.drop_block(x)
|
|
|
|
|
x = self.act2(x)
|
|
|
|
|
if self.aa is not None:
|
|
|
|
|
x = self.aa(x)
|
|
|
|
|
|
|
|
|
|
x = self.conv3(x)
|
|
|
|
|
x = self.bn3(x)
|
|
|
|
@ -353,8 +369,9 @@ class ResNet(nn.Module):
|
|
|
|
|
Whether to use average pooling for projection skip connection between stages/downsample.
|
|
|
|
|
output_stride : int, default 32
|
|
|
|
|
Set the output stride of the network, 32, 16, or 8. Typically used in segmentation.
|
|
|
|
|
act_layer : class, activation layer
|
|
|
|
|
norm_layer : class, normalization layer
|
|
|
|
|
act_layer : nn.Module, activation layer
|
|
|
|
|
norm_layer : nn.Module, normalization layer
|
|
|
|
|
aa_layer : nn.Module, anti-aliasing layer
|
|
|
|
|
drop_rate : float, default 0.
|
|
|
|
|
Dropout probability before classifier, for training
|
|
|
|
|
global_pool : str, default 'avg'
|
|
|
|
@ -363,7 +380,7 @@ class ResNet(nn.Module):
|
|
|
|
|
def __init__(self, block, layers, num_classes=1000, in_chans=3,
|
|
|
|
|
cardinality=1, base_width=64, stem_width=64, stem_type='',
|
|
|
|
|
block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32,
|
|
|
|
|
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, drop_path_rate=0.,
|
|
|
|
|
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0.0, drop_path_rate=0.,
|
|
|
|
|
drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None):
|
|
|
|
|
block_args = block_args or dict()
|
|
|
|
|
self.num_classes = num_classes
|
|
|
|
@ -393,7 +410,14 @@ class ResNet(nn.Module):
|
|
|
|
|
self.conv1 = nn.Conv2d(in_chans, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
|
|
|
|
|
self.bn1 = norm_layer(self.inplanes)
|
|
|
|
|
self.act1 = act_layer(inplace=True)
|
|
|
|
|
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
|
|
|
|
# Stem Pooling
|
|
|
|
|
if aa_layer is not None:
|
|
|
|
|
self.maxpool = nn.Sequential(*[
|
|
|
|
|
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
|
|
|
|
|
aa_layer(channels=self.inplanes, stride=2)
|
|
|
|
|
])
|
|
|
|
|
else:
|
|
|
|
|
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
|
|
|
|
|
|
|
|
|
# Feature Blocks
|
|
|
|
|
dp = DropPath(drop_path_rate) if drop_path_rate else None
|
|
|
|
@ -410,7 +434,7 @@ class ResNet(nn.Module):
|
|
|
|
|
assert output_stride == 32
|
|
|
|
|
layer_args = list(zip(channels, layers, strides, dilations))
|
|
|
|
|
layer_kwargs = dict(
|
|
|
|
|
reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer,
|
|
|
|
|
reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer,
|
|
|
|
|
avg_down=avg_down, down_kernel_size=down_kernel_size, drop_path=dp, **block_args)
|
|
|
|
|
self.layer1 = self._make_layer(block, *layer_args[0], **layer_kwargs)
|
|
|
|
|
self.layer2 = self._make_layer(block, *layer_args[1], **layer_kwargs)
|
|
|
|
@ -1114,3 +1138,29 @@ def ecaresnet101d_pruned(pretrained=False, num_classes=1000, in_chans=3, **kwarg
|
|
|
|
|
if pretrained:
|
|
|
|
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def resnetblur18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|
|
|
|
"""Constructs a ResNet-18 model with blur anti-aliasing
|
|
|
|
|
"""
|
|
|
|
|
default_cfg = default_cfgs['resnetblur18']
|
|
|
|
|
model = ResNet(
|
|
|
|
|
BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, aa_layer=BlurPool2d, **kwargs)
|
|
|
|
|
model.default_cfg = default_cfg
|
|
|
|
|
if pretrained:
|
|
|
|
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def resnetblur50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|
|
|
|
"""Constructs a ResNet-50 model with blur anti-aliasing
|
|
|
|
|
"""
|
|
|
|
|
default_cfg = default_cfgs['resnetblur50']
|
|
|
|
|
model = ResNet(
|
|
|
|
|
Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, aa_layer=BlurPool2d, **kwargs)
|
|
|
|
|
model.default_cfg = default_cfg
|
|
|
|
|
if pretrained:
|
|
|
|
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
|
|
|
return model
|
|
|
|
|