Fix BatchNorm for ResNetV2 non-GN models, add more ResNetV2 model defs for future experimentation, fix zero_init of last residual conv for pre-act.

Branch: pull/729/head
Author: Ross Wightman (3 years ago)
Parent: 02aaa785b9
Commit: e8045e712f

@@ -38,7 +38,8 @@ from functools import partial
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from .helpers import build_model_with_cfg, named_apply, adapt_input_conv
 from .registry import register_model
-from .layers import GroupNormAct, ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d
+from .layers import GroupNormAct, BatchNormAct2d, EvoNormBatch2d, EvoNormSample2d,\
+    ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d

 def _cfg(url='', **kwargs):
@@ -107,6 +108,16 @@ default_cfgs = {
         interpolation='bicubic'),
     'resnetv2_50d': _cfg(
         interpolation='bicubic', first_conv='stem.conv1'),
+    'resnetv2_50t': _cfg(
+        interpolation='bicubic', first_conv='stem.conv1'),
+    'resnetv2_101': _cfg(
+        interpolation='bicubic'),
+    'resnetv2_101d': _cfg(
+        interpolation='bicubic', first_conv='stem.conv1'),
+    'resnetv2_152': _cfg(
+        interpolation='bicubic'),
+    'resnetv2_152d': _cfg(
+        interpolation='bicubic', first_conv='stem.conv1'),
 }
@@ -152,8 +163,8 @@ class PreActBottleneck(nn.Module):
         self.conv3 = conv_layer(mid_chs, out_chs, 1)
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()

-    def zero_init_last_bn(self):
-        nn.init.zeros_(self.norm3.weight)
+    def zero_init_last(self):
+        nn.init.zeros_(self.conv3.weight)

     def forward(self, x):
         x_preact = self.norm1(x)
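The rename is paired with a change in what gets zeroed for the pre-act block: PreActBottleneck orders norm3 before conv3, so the residual branch ends in a conv, and zeroing conv3.weight makes the branch an exact no-op at init regardless of the norm's affine parameters, which is the usual intent of zero-init-last (the block starts as identity). The non-preact Bottleneck below keeps zeroing norm3.weight, since there the norm really is the last layer in the branch. A minimal standalone sketch of the pre-act case (toy layers, not the timm classes):

    import torch
    import torch.nn as nn

    # Toy pre-act residual branch: norm + act come *before* the final conv.
    branch = nn.Sequential(
        nn.GroupNorm(32, 64),              # stands in for norm3
        nn.ReLU(),                         # the act fused into GroupNormAct
        nn.Conv2d(64, 64, 1, bias=False),  # stands in for conv3, last op in the branch
    )
    nn.init.zeros_(branch[2].weight)       # the fix: zero the conv, not the norm
    x = torch.randn(2, 64, 8, 8)
    assert torch.all(branch(x) == 0)       # branch output is exactly zero at init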
@@ -201,7 +212,7 @@ class Bottleneck(nn.Module):
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
         self.act3 = act_layer(inplace=True)

-    def zero_init_last_bn(self):
+    def zero_init_last(self):
         nn.init.zeros_(self.norm3.weight)

     def forward(self, x):
@@ -284,17 +295,20 @@ def create_resnetv2_stem(
         in_chs, out_chs=64, stem_type='', preact=True,
         conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32)):
     stem = OrderedDict()
-    assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same')
+    assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same', 'tiered')

     # NOTE conv padding mode can be changed by overriding the conv_layer def
-    if 'deep' in stem_type:
+    if any([s in stem_type for s in ('deep', 'tiered')]):
         # A 3 deep 3x3 conv stack as in ResNet V1D models
-        mid_chs = out_chs // 2
-        stem['conv1'] = conv_layer(in_chs, mid_chs, kernel_size=3, stride=2)
-        stem['norm1'] = norm_layer(mid_chs)
-        stem['conv2'] = conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1)
-        stem['norm2'] = norm_layer(mid_chs)
-        stem['conv3'] = conv_layer(mid_chs, out_chs, kernel_size=3, stride=1)
+        if 'tiered' in stem_type:
+            stem_chs = (3 * out_chs // 8, out_chs // 2)  # 'T' resnets in resnet.py
+        else:
+            stem_chs = (out_chs // 2, out_chs // 2)  # 'D' ResNets
+        stem['conv1'] = conv_layer(in_chs, stem_chs[0], kernel_size=3, stride=2)
+        stem['norm1'] = norm_layer(stem_chs[0])
+        stem['conv2'] = conv_layer(stem_chs[0], stem_chs[1], kernel_size=3, stride=1)
+        stem['norm2'] = norm_layer(stem_chs[1])
+        stem['conv3'] = conv_layer(stem_chs[1], out_chs, kernel_size=3, stride=1)
     if not preact:
         stem['norm3'] = norm_layer(out_chs)
     else:
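For the default out_chs=64, the two channel plans work out as below (a quick arithmetic sketch, not part of the diff): the tiered stem ramps 3→24→32→64 as in the 'T' ResNets, while the deep stem uses 3→32→32→64 as in the 'D' ResNets.

    out_chs = 64
    tiered = (3 * out_chs // 8, out_chs // 2)  # 'T' stem widths
    deep = (out_chs // 2, out_chs // 2)        # 'D' stem widths
    print(tiered, deep)                        # (24, 32) (32, 32)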
@@ -326,7 +340,7 @@ class ResNetV2(nn.Module):
             num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
             width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True,
             act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32),
-            drop_rate=0., drop_path_rate=0., zero_init_last_bn=True):
+            drop_rate=0., drop_path_rate=0., zero_init_last=True):
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
@@ -364,10 +378,10 @@ class ResNetV2(nn.Module):
         self.head = ClassifierHead(
             self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)

-        self.init_weights(zero_init_last_bn=zero_init_last_bn)
+        self.init_weights(zero_init_last=zero_init_last)

-    def init_weights(self, zero_init_last_bn=True):
-        named_apply(partial(_init_weights, zero_init_last_bn=zero_init_last_bn), self)
+    def init_weights(self, zero_init_last=True):
+        named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)

     @torch.jit.ignore()
     def load_pretrained(self, checkpoint_path, prefix='resnet/'):
@@ -393,7 +407,7 @@ class ResNetV2(nn.Module):
         return x


-def _init_weights(module: nn.Module, name: str = '', zero_init_last_bn=True):
+def _init_weights(module: nn.Module, name: str = '', zero_init_last=True):
     if isinstance(module, nn.Linear) or ('head.fc' in name and isinstance(module, nn.Conv2d)):
         nn.init.normal_(module.weight, mean=0.0, std=0.01)
         nn.init.zeros_(module.bias)
@@ -404,8 +418,8 @@ def _init_weights(module: nn.Module, name: str = '', zero_init_last_bn=True):
     elif isinstance(module, (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)):
         nn.init.ones_(module.weight)
         nn.init.zeros_(module.bias)
-    elif zero_init_last_bn and hasattr(module, 'zero_init_last_bn'):
-        module.zero_init_last_bn()
+    elif zero_init_last and hasattr(module, 'zero_init_last'):
+        module.zero_init_last()

 @torch.no_grad()
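Because the hook is discovered via hasattr rather than by layer type, each block decides what "last" means for itself: PreActBottleneck zeros a conv, Bottleneck zeros a norm. A minimal standalone sketch of this dispatch (toy module, not the timm code):

    import torch.nn as nn

    def init_weights(module: nn.Module, zero_init_last=True):
        # mirrors the final elif above: defer to the block's own hook
        if zero_init_last and hasattr(module, 'zero_init_last'):
            module.zero_init_last()

    class ToyBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.norm3 = nn.BatchNorm2d(8)

        def zero_init_last(self):  # same hook name the diff standardizes on
            nn.init.zeros_(self.norm3.weight)

    block = ToyBlock()
    block.apply(init_weights)  # analogous to named_apply over the whole model
    assert block.norm3.weight.sum() == 0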
@@ -570,12 +584,68 @@ def resnetv2_152x2_bit_teacher_384(pretrained=False, **kwargs):
 def resnetv2_50(pretrained=False, **kwargs):
     return _create_resnetv2(
         'resnetv2_50', pretrained=pretrained,
-        layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=nn.BatchNorm2d, **kwargs)
+        layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs)
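This swap is the BatchNorm fix from the commit message: the non-preact blocks in this file assume norm_layer bundles norm and activation (as GroupNormAct does), so plain nn.BatchNorm2d silently dropped every post-norm ReLU. BatchNormAct2d behaves roughly like BatchNorm2d followed by an activation. A quick check, assuming its defaults apply a ReLU:

    import torch
    from timm.models.layers import BatchNormAct2d

    bn_act = BatchNormAct2d(64)       # defaults assumed: ReLU applied after BN
    x = torch.randn(2, 64, 8, 8)
    assert torch.all(bn_act(x) >= 0)  # the activation is actually applied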

 @register_model
 def resnetv2_50d(pretrained=False, **kwargs):
     return _create_resnetv2(
         'resnetv2_50d', pretrained=pretrained,
-        layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=nn.BatchNorm2d,
+        layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
         stem_type='deep', avg_down=True, **kwargs)
+
+
+@register_model
+def resnetv2_50t(pretrained=False, **kwargs):
+    return _create_resnetv2(
+        'resnetv2_50t', pretrained=pretrained,
+        layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
+        stem_type='tiered', avg_down=True, **kwargs)
+
+
+@register_model
+def resnetv2_101(pretrained=False, **kwargs):
+    return _create_resnetv2(
+        'resnetv2_101', pretrained=pretrained,
+        layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs)
+
+
+@register_model
+def resnetv2_101d(pretrained=False, **kwargs):
+    return _create_resnetv2(
+        'resnetv2_101d', pretrained=pretrained,
+        layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
+        stem_type='deep', avg_down=True, **kwargs)
+
+
+@register_model
+def resnetv2_152(pretrained=False, **kwargs):
+    return _create_resnetv2(
+        'resnetv2_152', pretrained=pretrained,
+        layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs)
+
+
+@register_model
+def resnetv2_152d(pretrained=False, **kwargs):
+    return _create_resnetv2(
+        'resnetv2_152d', pretrained=pretrained,
+        layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
+        stem_type='deep', avg_down=True, **kwargs)
+
+
+# @register_model
+# def resnetv2_50ebd(pretrained=False, **kwargs):
+#     # FIXME for testing w/ TPU + PyTorch XLA
+#     return _create_resnetv2(
+#         'resnetv2_50d', pretrained=pretrained,
+#         layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormBatch2d,
+#         stem_type='deep', avg_down=True, **kwargs)
+#
+#
+# @register_model
+# def resnetv2_50esd(pretrained=False, **kwargs):
+#     # FIXME for testing w/ TPU + PyTorch XLA
+#     return _create_resnetv2(
+#         'resnetv2_50d', pretrained=pretrained,
+#         layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormSample2d,
+#         stem_type='deep', avg_down=True, **kwargs)
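With this commit checked out, the new registrations can be created the usual way; a usage sketch (these defs ship without pretrained weights, hence pretrained=False):

    import timm
    import torch

    model = timm.create_model('resnetv2_50t', pretrained=False, num_classes=10)
    out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # torch.Size([1, 10])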
