From e8045e712f0c5e977558f4457813858e9b70c613 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 28 Jun 2021 10:52:45 -0700 Subject: [PATCH] Fix BatchNorm for ResNetV2 non GN models, add more ResNetV2 model defs for future experimentation, fix zero_init of last residual for pre-act. --- timm/models/resnetv2.py | 112 ++++++++++++++++++++++++++++++++-------- 1 file changed, 91 insertions(+), 21 deletions(-) diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index b96d7742..4fd3b823 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -38,7 +38,8 @@ from functools import partial from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from .helpers import build_model_with_cfg, named_apply, adapt_input_conv from .registry import register_model -from .layers import GroupNormAct, ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d +from .layers import GroupNormAct, BatchNormAct2d, EvoNormBatch2d, EvoNormSample2d,\ + ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d def _cfg(url='', **kwargs): @@ -107,6 +108,16 @@ default_cfgs = { interpolation='bicubic'), 'resnetv2_50d': _cfg( interpolation='bicubic', first_conv='stem.conv1'), + 'resnetv2_50t': _cfg( + interpolation='bicubic', first_conv='stem.conv1'), + 'resnetv2_101': _cfg( + interpolation='bicubic'), + 'resnetv2_101d': _cfg( + interpolation='bicubic', first_conv='stem.conv1'), + 'resnetv2_152': _cfg( + interpolation='bicubic'), + 'resnetv2_152d': _cfg( + interpolation='bicubic', first_conv='stem.conv1'), } @@ -152,8 +163,8 @@ class PreActBottleneck(nn.Module): self.conv3 = conv_layer(mid_chs, out_chs, 1) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() - def zero_init_last_bn(self): - nn.init.zeros_(self.norm3.weight) + def zero_init_last(self): + nn.init.zeros_(self.conv3.weight) def forward(self, x): x_preact = self.norm1(x) @@ -201,7 +212,7 @@ class Bottleneck(nn.Module): self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() self.act3 = act_layer(inplace=True) - def zero_init_last_bn(self): + def zero_init_last(self): nn.init.zeros_(self.norm3.weight) def forward(self, x): @@ -284,17 +295,20 @@ def create_resnetv2_stem( in_chs, out_chs=64, stem_type='', preact=True, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32)): stem = OrderedDict() - assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same') + assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same', 'tiered') # NOTE conv padding mode can be changed by overriding the conv_layer def - if 'deep' in stem_type: + if any([s in stem_type for s in ('deep', 'tiered')]): # A 3 deep 3x3 conv stack as in ResNet V1D models - mid_chs = out_chs // 2 - stem['conv1'] = conv_layer(in_chs, mid_chs, kernel_size=3, stride=2) - stem['norm1'] = norm_layer(mid_chs) - stem['conv2'] = conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1) - stem['norm2'] = norm_layer(mid_chs) - stem['conv3'] = conv_layer(mid_chs, out_chs, kernel_size=3, stride=1) + if 'tiered' in stem_type: + stem_chs = (3 * out_chs // 8, out_chs // 2) # 'T' resnets in resnet.py + else: + stem_chs = (out_chs // 2, out_chs // 2) # 'D' ResNets + stem['conv1'] = conv_layer(in_chs, stem_chs[0], kernel_size=3, stride=2) + stem['norm1'] = norm_layer(stem_chs[0]) + stem['conv2'] = conv_layer(stem_chs[0], stem_chs[1], kernel_size=3, stride=1) + stem['norm2'] = norm_layer(stem_chs[1]) + stem['conv3'] = conv_layer(stem_chs[1], out_chs, kernel_size=3, stride=1) if not preact: stem['norm3'] = norm_layer(out_chs) else: @@ -326,7 +340,7 @@ class ResNetV2(nn.Module): num_classes=1000, in_chans=3, global_pool='avg', output_stride=32, width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True, act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32), - drop_rate=0., drop_path_rate=0., zero_init_last_bn=True): + drop_rate=0., drop_path_rate=0., zero_init_last=True): super().__init__() self.num_classes = num_classes self.drop_rate = drop_rate @@ -364,10 +378,10 @@ class ResNetV2(nn.Module): self.head = ClassifierHead( self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True) - self.init_weights(zero_init_last_bn=zero_init_last_bn) + self.init_weights(zero_init_last=zero_init_last) - def init_weights(self, zero_init_last_bn=True): - named_apply(partial(_init_weights, zero_init_last_bn=zero_init_last_bn), self) + def init_weights(self, zero_init_last=True): + named_apply(partial(_init_weights, zero_init_last=zero_init_last), self) @torch.jit.ignore() def load_pretrained(self, checkpoint_path, prefix='resnet/'): @@ -393,7 +407,7 @@ class ResNetV2(nn.Module): return x -def _init_weights(module: nn.Module, name: str = '', zero_init_last_bn=True): +def _init_weights(module: nn.Module, name: str = '', zero_init_last=True): if isinstance(module, nn.Linear) or ('head.fc' in name and isinstance(module, nn.Conv2d)): nn.init.normal_(module.weight, mean=0.0, std=0.01) nn.init.zeros_(module.bias) @@ -404,8 +418,8 @@ def _init_weights(module: nn.Module, name: str = '', zero_init_last_bn=True): elif isinstance(module, (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)): nn.init.ones_(module.weight) nn.init.zeros_(module.bias) - elif zero_init_last_bn and hasattr(module, 'zero_init_last_bn'): - module.zero_init_last_bn() + elif zero_init_last and hasattr(module, 'zero_init_last'): + module.zero_init_last() @torch.no_grad() @@ -570,12 +584,68 @@ def resnetv2_152x2_bit_teacher_384(pretrained=False, **kwargs): def resnetv2_50(pretrained=False, **kwargs): return _create_resnetv2( 'resnetv2_50', pretrained=pretrained, - layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=nn.BatchNorm2d, **kwargs) + layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs) @register_model def resnetv2_50d(pretrained=False, **kwargs): return _create_resnetv2( 'resnetv2_50d', pretrained=pretrained, - layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=nn.BatchNorm2d, + layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, stem_type='deep', avg_down=True, **kwargs) + + +@register_model +def resnetv2_50t(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_50t', pretrained=pretrained, + layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, + stem_type='tiered', avg_down=True, **kwargs) + + +@register_model +def resnetv2_101(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_101', pretrained=pretrained, + layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs) + + +@register_model +def resnetv2_101d(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_101d', pretrained=pretrained, + layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, + stem_type='deep', avg_down=True, **kwargs) + + +@register_model +def resnetv2_152(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_152', pretrained=pretrained, + layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs) + + +@register_model +def resnetv2_152d(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_152d', pretrained=pretrained, + layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, + stem_type='deep', avg_down=True, **kwargs) + + +# @register_model +# def resnetv2_50ebd(pretrained=False, **kwargs): +# # FIXME for testing w/ TPU + PyTorch XLA +# return _create_resnetv2( +# 'resnetv2_50d', pretrained=pretrained, +# layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormBatch2d, +# stem_type='deep', avg_down=True, **kwargs) +# +# +# @register_model +# def resnetv2_50esd(pretrained=False, **kwargs): +# # FIXME for testing w/ TPU + PyTorch XLA +# return _create_resnetv2( +# 'resnetv2_50d', pretrained=pretrained, +# layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormSample2d, +# stem_type='deep', avg_down=True, **kwargs)