|
|
|
@ -38,7 +38,8 @@ from functools import partial
|
|
|
|
|
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
|
|
|
|
from .helpers import build_model_with_cfg, named_apply, adapt_input_conv
|
|
|
|
|
from .registry import register_model
|
|
|
|
|
from .layers import GroupNormAct, ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d
|
|
|
|
|
from .layers import GroupNormAct, BatchNormAct2d, EvoNormBatch2d, EvoNormSample2d,\
|
|
|
|
|
ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _cfg(url='', **kwargs):
|
|
|
|
@ -107,6 +108,16 @@ default_cfgs = {
|
|
|
|
|
interpolation='bicubic'),
|
|
|
|
|
'resnetv2_50d': _cfg(
|
|
|
|
|
interpolation='bicubic', first_conv='stem.conv1'),
|
|
|
|
|
'resnetv2_50t': _cfg(
|
|
|
|
|
interpolation='bicubic', first_conv='stem.conv1'),
|
|
|
|
|
'resnetv2_101': _cfg(
|
|
|
|
|
interpolation='bicubic'),
|
|
|
|
|
'resnetv2_101d': _cfg(
|
|
|
|
|
interpolation='bicubic', first_conv='stem.conv1'),
|
|
|
|
|
'resnetv2_152': _cfg(
|
|
|
|
|
interpolation='bicubic'),
|
|
|
|
|
'resnetv2_152d': _cfg(
|
|
|
|
|
interpolation='bicubic', first_conv='stem.conv1'),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -152,8 +163,8 @@ class PreActBottleneck(nn.Module):
|
|
|
|
|
self.conv3 = conv_layer(mid_chs, out_chs, 1)
|
|
|
|
|
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
|
|
|
|
|
|
|
|
|
|
def zero_init_last_bn(self):
    """Fill the weight of this block's `norm3` layer with zeros, in place.

    NOTE(review): presumably called during model weight init so the residual
    branch starts out contributing nothing -- confirm against the caller.
    """
    last_norm_weight = self.norm3.weight
    nn.init.zeros_(last_norm_weight)
|
|
|
|
|
def zero_init_last(self):
    """Fill the weight of this block's `conv3` layer with zeros, in place.

    NOTE(review): presumably called during model weight init so the residual
    branch starts out contributing nothing -- confirm against the caller.
    """
    last_conv_weight = self.conv3.weight
    nn.init.zeros_(last_conv_weight)
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
x_preact = self.norm1(x)
|
|
|
|
@ -201,7 +212,7 @@ class Bottleneck(nn.Module):
|
|
|
|
|
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
|
|
|
|
|
self.act3 = act_layer(inplace=True)
|
|
|
|
|
|
|
|
|
|
def zero_init_last_bn(self):
|
|
|
|
|
def zero_init_last(self):
    """Fill the weight of this block's `norm3` layer with zeros, in place.

    NOTE(review): presumably called during model weight init so the residual
    branch starts out contributing nothing -- confirm against the caller.
    """
    last_norm_weight = self.norm3.weight
    nn.init.zeros_(last_norm_weight)
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
@ -284,17 +295,20 @@ def create_resnetv2_stem(
|
|
|
|
|
in_chs, out_chs=64, stem_type='', preact=True,
|
|
|
|
|
conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32)):
|
|
|
|
|
stem = OrderedDict()
|
|
|
|
|
assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same')
|
|
|
|
|
assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same', 'tiered')
|
|
|
|
|
|
|
|
|
|
# NOTE conv padding mode can be changed by overriding the conv_layer def
|
|
|
|
|
if 'deep' in stem_type:
|
|
|
|
|
if any([s in stem_type for s in ('deep', 'tiered')]):
|
|
|
|
|
# A 3 deep 3x3 conv stack as in ResNet V1D models
|
|
|
|
|
mid_chs = out_chs // 2
|
|
|
|
|
stem['conv1'] = conv_layer(in_chs, mid_chs, kernel_size=3, stride=2)
|
|
|
|
|
stem['norm1'] = norm_layer(mid_chs)
|
|
|
|
|
stem['conv2'] = conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1)
|
|
|
|
|
stem['norm2'] = norm_layer(mid_chs)
|
|
|
|
|
stem['conv3'] = conv_layer(mid_chs, out_chs, kernel_size=3, stride=1)
|
|
|
|
|
if 'tiered' in stem_type:
|
|
|
|
|
stem_chs = (3 * out_chs // 8, out_chs // 2) # 'T' resnets in resnet.py
|
|
|
|
|
else:
|
|
|
|
|
stem_chs = (out_chs // 2, out_chs // 2) # 'D' ResNets
|
|
|
|
|
stem['conv1'] = conv_layer(in_chs, stem_chs[0], kernel_size=3, stride=2)
|
|
|
|
|
stem['norm1'] = norm_layer(stem_chs[0])
|
|
|
|
|
stem['conv2'] = conv_layer(stem_chs[0], stem_chs[1], kernel_size=3, stride=1)
|
|
|
|
|
stem['norm2'] = norm_layer(stem_chs[1])
|
|
|
|
|
stem['conv3'] = conv_layer(stem_chs[1], out_chs, kernel_size=3, stride=1)
|
|
|
|
|
if not preact:
|
|
|
|
|
stem['norm3'] = norm_layer(out_chs)
|
|
|
|
|
else:
|
|
|
|
@ -326,7 +340,7 @@ class ResNetV2(nn.Module):
|
|
|
|
|
num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
|
|
|
|
|
width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True,
|
|
|
|
|
act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32),
|
|
|
|
|
drop_rate=0., drop_path_rate=0., zero_init_last_bn=True):
|
|
|
|
|
drop_rate=0., drop_path_rate=0., zero_init_last=True):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.num_classes = num_classes
|
|
|
|
|
self.drop_rate = drop_rate
|
|
|
|
@ -364,10 +378,10 @@ class ResNetV2(nn.Module):
|
|
|
|
|
self.head = ClassifierHead(
|
|
|
|
|
self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)
|
|
|
|
|
|
|
|
|
|
self.init_weights(zero_init_last_bn=zero_init_last_bn)
|
|
|
|
|
self.init_weights(zero_init_last=zero_init_last)
|
|
|
|
|
|
|
|
|
|
def init_weights(self, zero_init_last_bn=True):
    """Initialise this model by handing `_init_weights` to `named_apply`.

    Args:
        zero_init_last_bn: bound as the `zero_init_last_bn` keyword of
            `_init_weights` before it is applied.
    """
    init_fn = partial(_init_weights, zero_init_last_bn=zero_init_last_bn)
    named_apply(init_fn, self)
|
|
|
|
|
def init_weights(self, zero_init_last=True):
    """Initialise this model by handing `_init_weights` to `named_apply`.

    Args:
        zero_init_last: bound as the `zero_init_last` keyword of
            `_init_weights` before it is applied.
    """
    init_fn = partial(_init_weights, zero_init_last=zero_init_last)
    named_apply(init_fn, self)
|
|
|
|
|
|
|
|
|
|
@torch.jit.ignore()
|
|
|
|
|
def load_pretrained(self, checkpoint_path, prefix='resnet/'):
|
|
|
|
@ -393,7 +407,7 @@ class ResNetV2(nn.Module):
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _init_weights(module: nn.Module, name: str = '', zero_init_last_bn=True):
|
|
|
|
|
def _init_weights(module: nn.Module, name: str = '', zero_init_last=True):
|
|
|
|
|
if isinstance(module, nn.Linear) or ('head.fc' in name and isinstance(module, nn.Conv2d)):
|
|
|
|
|
nn.init.normal_(module.weight, mean=0.0, std=0.01)
|
|
|
|
|
nn.init.zeros_(module.bias)
|
|
|
|
@ -404,8 +418,8 @@ def _init_weights(module: nn.Module, name: str = '', zero_init_last_bn=True):
|
|
|
|
|
elif isinstance(module, (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)):
|
|
|
|
|
nn.init.ones_(module.weight)
|
|
|
|
|
nn.init.zeros_(module.bias)
|
|
|
|
|
elif zero_init_last_bn and hasattr(module, 'zero_init_last_bn'):
|
|
|
|
|
module.zero_init_last_bn()
|
|
|
|
|
elif zero_init_last and hasattr(module, 'zero_init_last'):
|
|
|
|
|
module.zero_init_last()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
@ -570,12 +584,68 @@ def resnetv2_152x2_bit_teacher_384(pretrained=False, **kwargs):
|
|
|
|
|
def resnetv2_50(pretrained=False, **kwargs):
    """Build the `resnetv2_50` variant via `_create_resnetv2`.

    The fixed configuration is `layers=[3, 4, 6, 3]`, `conv_layer=create_conv2d`
    and `norm_layer=BatchNormAct2d`. Only one `layers=/conv_layer=/norm_layer=`
    line is kept: passing both the `nn.BatchNorm2d` and the `BatchNormAct2d`
    variants in the same call repeats keyword arguments, which is a SyntaxError.
    `BatchNormAct2d` matches the other builders in this file.

    Args:
        pretrained: forwarded unchanged to `_create_resnetv2`.
        **kwargs: extra keyword arguments forwarded to `_create_resnetv2`;
            repeating one of the fixed arguments raises TypeError at the call.

    Returns:
        Whatever `_create_resnetv2` returns for this configuration.
    """
    model_args = dict(
        layers=[3, 4, 6, 3],
        conv_layer=create_conv2d,
        norm_layer=BatchNormAct2d,
    )
    return _create_resnetv2('resnetv2_50', pretrained=pretrained, **model_args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def resnetv2_50d(pretrained=False, **kwargs):
    """Build the `resnetv2_50d` variant via `_create_resnetv2`.

    The fixed configuration is `layers=[3, 4, 6, 3]`, `conv_layer=create_conv2d`,
    `norm_layer=BatchNormAct2d`, `stem_type='deep'` and `avg_down=True`. Only one
    `layers=/conv_layer=/norm_layer=` line is kept: passing both the
    `nn.BatchNorm2d` and the `BatchNormAct2d` variants in the same call repeats
    keyword arguments, which is a SyntaxError. `BatchNormAct2d` matches the
    other 'd' builders in this file.

    Args:
        pretrained: forwarded unchanged to `_create_resnetv2`.
        **kwargs: extra keyword arguments forwarded to `_create_resnetv2`;
            repeating one of the fixed arguments raises TypeError at the call.

    Returns:
        Whatever `_create_resnetv2` returns for this configuration.
    """
    model_args = dict(
        layers=[3, 4, 6, 3],
        conv_layer=create_conv2d,
        norm_layer=BatchNormAct2d,
        stem_type='deep',
        avg_down=True,
    )
    return _create_resnetv2('resnetv2_50d', pretrained=pretrained, **model_args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def resnetv2_50t(pretrained=False, **kwargs):
    """Build the `resnetv2_50t` variant via `_create_resnetv2`.

    The fixed configuration is `layers=[3, 4, 6, 3]`, `conv_layer=create_conv2d`,
    `norm_layer=BatchNormAct2d`, `stem_type='tiered'` and `avg_down=True`.

    Args:
        pretrained: forwarded unchanged to `_create_resnetv2`.
        **kwargs: extra keyword arguments forwarded to `_create_resnetv2`;
            repeating one of the fixed arguments raises TypeError at the call.

    Returns:
        Whatever `_create_resnetv2` returns for this configuration.
    """
    model_args = dict(
        layers=[3, 4, 6, 3],
        conv_layer=create_conv2d,
        norm_layer=BatchNormAct2d,
        stem_type='tiered',
        avg_down=True,
    )
    return _create_resnetv2('resnetv2_50t', pretrained=pretrained, **model_args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def resnetv2_101(pretrained=False, **kwargs):
    """Build the `resnetv2_101` variant via `_create_resnetv2`.

    The fixed configuration is `layers=[3, 4, 23, 3]`, `conv_layer=create_conv2d`
    and `norm_layer=BatchNormAct2d`.

    Args:
        pretrained: forwarded unchanged to `_create_resnetv2`.
        **kwargs: extra keyword arguments forwarded to `_create_resnetv2`;
            repeating one of the fixed arguments raises TypeError at the call.

    Returns:
        Whatever `_create_resnetv2` returns for this configuration.
    """
    model_args = dict(
        layers=[3, 4, 23, 3],
        conv_layer=create_conv2d,
        norm_layer=BatchNormAct2d,
    )
    return _create_resnetv2('resnetv2_101', pretrained=pretrained, **model_args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def resnetv2_101d(pretrained=False, **kwargs):
    """Build the `resnetv2_101d` variant via `_create_resnetv2`.

    The fixed configuration is `layers=[3, 4, 23, 3]`, `conv_layer=create_conv2d`,
    `norm_layer=BatchNormAct2d`, `stem_type='deep'` and `avg_down=True`.

    Args:
        pretrained: forwarded unchanged to `_create_resnetv2`.
        **kwargs: extra keyword arguments forwarded to `_create_resnetv2`;
            repeating one of the fixed arguments raises TypeError at the call.

    Returns:
        Whatever `_create_resnetv2` returns for this configuration.
    """
    model_args = dict(
        layers=[3, 4, 23, 3],
        conv_layer=create_conv2d,
        norm_layer=BatchNormAct2d,
        stem_type='deep',
        avg_down=True,
    )
    return _create_resnetv2('resnetv2_101d', pretrained=pretrained, **model_args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def resnetv2_152(pretrained=False, **kwargs):
    """Build the `resnetv2_152` variant via `_create_resnetv2`.

    The fixed configuration is `layers=[3, 8, 36, 3]`, `conv_layer=create_conv2d`
    and `norm_layer=BatchNormAct2d`.

    Args:
        pretrained: forwarded unchanged to `_create_resnetv2`.
        **kwargs: extra keyword arguments forwarded to `_create_resnetv2`;
            repeating one of the fixed arguments raises TypeError at the call.

    Returns:
        Whatever `_create_resnetv2` returns for this configuration.
    """
    model_args = dict(
        layers=[3, 8, 36, 3],
        conv_layer=create_conv2d,
        norm_layer=BatchNormAct2d,
    )
    return _create_resnetv2('resnetv2_152', pretrained=pretrained, **model_args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def resnetv2_152d(pretrained=False, **kwargs):
    """Build the `resnetv2_152d` variant via `_create_resnetv2`.

    The fixed configuration is `layers=[3, 8, 36, 3]`, `conv_layer=create_conv2d`,
    `norm_layer=BatchNormAct2d`, `stem_type='deep'` and `avg_down=True`.

    Args:
        pretrained: forwarded unchanged to `_create_resnetv2`.
        **kwargs: extra keyword arguments forwarded to `_create_resnetv2`;
            repeating one of the fixed arguments raises TypeError at the call.

    Returns:
        Whatever `_create_resnetv2` returns for this configuration.
    """
    model_args = dict(
        layers=[3, 8, 36, 3],
        conv_layer=create_conv2d,
        norm_layer=BatchNormAct2d,
        stem_type='deep',
        avg_down=True,
    )
    return _create_resnetv2('resnetv2_152d', pretrained=pretrained, **model_args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# @register_model
|
|
|
|
|
# def resnetv2_50ebd(pretrained=False, **kwargs):
|
|
|
|
|
# # FIXME for testing w/ TPU + PyTorch XLA
|
|
|
|
|
# return _create_resnetv2(
|
|
|
|
|
# 'resnetv2_50d', pretrained=pretrained,
|
|
|
|
|
# layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormBatch2d,
|
|
|
|
|
# stem_type='deep', avg_down=True, **kwargs)
|
|
|
|
|
#
|
|
|
|
|
#
|
|
|
|
|
# @register_model
|
|
|
|
|
# def resnetv2_50esd(pretrained=False, **kwargs):
|
|
|
|
|
# # FIXME for testing w/ TPU + PyTorch XLA
|
|
|
|
|
# return _create_resnetv2(
|
|
|
|
|
# 'resnetv2_50d', pretrained=pretrained,
|
|
|
|
|
# layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormSample2d,
|
|
|
|
|
# stem_type='deep', avg_down=True, **kwargs)
|
|
|
|
|