diff --git a/timm/models/__init__.py b/timm/models/__init__.py
index b4fe1dea..47cffac8 100644
--- a/timm/models/__init__.py
+++ b/timm/models/__init__.py
@@ -1,30 +1,31 @@
-from .inception_v4 import *
-from .inception_resnet_v2 import *
 from .densenet import *
-from .resnet import *
+from .dla import *
 from .dpn import *
-from .senet import *
-from .xception import *
-from .nasnet import *
-from .pnasnet import *
-from .selecsls import *
 from .efficientnet import *
-from .mobilenetv3 import *
-from .inception_v3 import *
 from .gluon_resnet import *
 from .gluon_xception import *
-from .res2net import *
-from .dla import *
 from .hrnet import *
+from .inception_resnet_v2 import *
+from .inception_v3 import *
+from .inception_v4 import *
+from .mobilenetv3 import *
+from .nasnet import *
+from .pnasnet import *
+from .regnet import *
+from .res2net import *
+from .resnest import *
+from .resnet import *
+from .selecsls import *
+from .senet import *
 from .sknet import *
 from .tresnet import *
-from .resnest import *
-from .regnet import *
 from .vovnet import *
+from .xception import *
+from .xception_aligned import *
 
-from .registry import *
 from .factory import create_model
 from .helpers import load_checkpoint, resume_checkpoint
 from .layers import TestTimePoolHead, apply_test_time_pool
 from .layers import convert_splitbn_model
 from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
+from .registry import *
diff --git a/timm/models/helpers.py b/timm/models/helpers.py
index 3baad3bf..48c254d4 100644
--- a/timm/models/helpers.py
+++ b/timm/models/helpers.py
@@ -74,6 +74,9 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
 
     state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu')
 
+    if filter_fn is not None:
+        state_dict = filter_fn(state_dict)
+
     if in_chans == 1:
         conv1_name = cfg['first_conv']
         logging.info('Converting first conv (%s) from 3 to 1 channel' % conv1_name)
@@ -95,9 +98,6 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
             del state_dict[classifier_name + '.bias']
             strict = False
 
-    if filter_fn is not None:
-        state_dict = filter_fn(state_dict)
-
     model.load_state_dict(state_dict, strict=strict)
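Note on the load_pretrained change: moving the filter_fn call ahead of the in_chans/classifier adjustments means key remapping now happens on the raw checkpoint, so the first-conv/classifier surgery operates on already-filtered keys. A minimal sketch of a filter_fn (the prefix-stripping helper below is illustrative, not part of this patch):

```python
# Hypothetical filter_fn: strip the 'module.' prefix that DataParallel
# checkpoints carry, before load_pretrained adapts conv1/classifier weights.
def strip_module_prefix(state_dict):
    return {k[7:] if k.startswith('module.') else k: v  # len('module.') == 7
            for k, v in state_dict.items()}

# load_pretrained(model, filter_fn=strip_module_prefix)  # hypothetical call site
```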
diff --git a/timm/models/inception_resnet_v2.py b/timm/models/inception_resnet_v2.py
index 951648c7..85c00486 100644
--- a/timm/models/inception_resnet_v2.py
+++ b/timm/models/inception_resnet_v2.py
@@ -223,11 +223,12 @@ class Block8(nn.Module):
 
 class InceptionResnetV2(nn.Module):
 
-    def __init__(self, num_classes=1001, in_chans=3, drop_rate=0., global_pool='avg'):
+    def __init__(self, num_classes=1001, in_chans=3, drop_rate=0., output_stride=32, global_pool='avg'):
         super(InceptionResnetV2, self).__init__()
         self.drop_rate = drop_rate
         self.num_classes = num_classes
         self.num_features = 1536
+        assert output_stride == 32
 
         self.conv2d_1a = BasicConv2d(in_chans, 32, kernel_size=3, stride=2)
         self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
@@ -340,16 +341,16 @@ class InceptionResnetV2(nn.Module):
 
 
 def _inception_resnet_v2(variant, pretrained=False, **kwargs):
-    load_strict, features, out_indices = True, False, None
+    features, out_indices = False, None
     if kwargs.pop('features_only', False):
-        load_strict, features, out_indices = False, True, kwargs.pop('out_indices', (0, 1, 2, 3, 4))
-        kwargs.pop('num_classes', 0)
+        features = True
+        out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
     model = InceptionResnetV2(**kwargs)
     model.default_cfg = default_cfgs[variant]
     if pretrained:
         load_pretrained(
             model,
-            num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=load_strict)
+            num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=not features)
     if features:
         model = FeatureNet(model, out_indices)
     return model
diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py
index bc802ba2..24e3f2a8 100644
--- a/timm/models/nasnet.py
+++ b/timm/models/nasnet.py
@@ -400,14 +400,15 @@ class ReductionCell1(nn.Module):
 
 class NASNetALarge(nn.Module):
     """NASNetALarge (6 @ 4032) """
 
-    def __init__(self, num_classes=1000, in_chans=1, stem_size=96, num_features=4032, channel_multiplier=2,
-                 drop_rate=0., global_pool='avg', pad_type='same'):
+    def __init__(self, num_classes=1000, in_chans=1, stem_size=96, channel_multiplier=2,
+                 num_features=4032, output_stride=32, drop_rate=0., global_pool='avg', pad_type='same'):
         super(NASNetALarge, self).__init__()
         self.num_classes = num_classes
         self.stem_size = stem_size
         self.num_features = num_features
         self.channel_multiplier = channel_multiplier
         self.drop_rate = drop_rate
+        assert output_stride == 32
 
         channels = self.num_features // 24  # 24 is default value for the architecture
 
diff --git a/timm/models/pnasnet.py b/timm/models/pnasnet.py
index dc0f078d..fb2eb0dd 100644
--- a/timm/models/pnasnet.py
+++ b/timm/models/pnasnet.py
@@ -236,11 +236,12 @@ class Cell(CellBase):
 
 class PNASNet5Large(nn.Module):
-    def __init__(self, num_classes=1001, in_chans=3, drop_rate=0.5, global_pool='avg', padding=''):
+    def __init__(self, num_classes=1001, in_chans=3, output_stride=32, drop_rate=0.5, global_pool='avg', padding=''):
         super(PNASNet5Large, self).__init__()
         self.num_classes = num_classes
-        self.num_features = 4320
         self.drop_rate = drop_rate
+        self.num_features = 4320
+        assert output_stride == 32
 
         self.conv_0 = ConvBnAct(
             in_chans, 96, kernel_size=3, stride=2, padding=0,
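The helper pattern introduced above (pop features_only from kwargs, load weights with strict=not features, then wrap in FeatureNet) recurs for every model touched by this patch. A usage sketch, assuming the wrapped model returns one feature map per requested index:

```python
import torch
import timm

# features_only routes through FeatureNet; out_indices selects which of the
# feature_info stages are returned in place of classification logits
model = timm.create_model('inception_resnet_v2', features_only=True, out_indices=(2, 3, 4))
features = model(torch.randn(1, 3, 299, 299))  # expect one tensor per index
```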
""" - def __init__(self, in_chs, out_chs, stride=1, bottleneck_ratio=1, group_width=1, se_ratio=0.25, - dilation=1, first_dilation=None, downsample=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, - aa_layer=None, drop_block=None, drop_path=None): + def __init__(self, in_chs, out_chs, stride=1, dilation=1, bottleneck_ratio=1, group_width=1, se_ratio=0.25, + downsample=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, + drop_block=None, drop_path=None): super(Bottleneck, self).__init__() bottleneck_chs = int(round(out_chs * bottleneck_ratio)) groups = bottleneck_chs // group_width - first_dilation = first_dilation or dilation cargs = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer, drop_block=drop_block) self.conv1 = ConvBnAct(in_chs, bottleneck_chs, kernel_size=1, **cargs) self.conv2 = ConvBnAct( - bottleneck_chs, bottleneck_chs, kernel_size=3, stride=stride, dilation=first_dilation, + bottleneck_chs, bottleneck_chs, kernel_size=3, stride=stride, dilation=dilation, groups=groups, **cargs) if se_ratio: se_channels = int(round(in_chs * se_ratio)) @@ -172,16 +171,16 @@ class Bottleneck(nn.Module): def downsample_conv( - in_chs, out_chs, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None): + in_chs, out_chs, kernel_size, stride=1, dilation=1, norm_layer=None): norm_layer = norm_layer or nn.BatchNorm2d kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size - first_dilation = (first_dilation or dilation) if kernel_size > 1 else 1 + dilation = dilation if kernel_size > 1 else 1 return ConvBnAct( - in_chs, out_chs, kernel_size, stride=stride, dilation=first_dilation, norm_layer=norm_layer, act_layer=None) + in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, norm_layer=norm_layer, act_layer=None) def downsample_avg( - in_chs, out_chs, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None): + in_chs, out_chs, kernel_size, stride=1, dilation=1, norm_layer=None): """ AvgPool Downsampling as in 'D' ResNet variants. 
This is not in RegNet space but I might experiment.""" norm_layer = norm_layer or nn.BatchNorm2d avg_stride = stride if dilation == 1 else 1 @@ -196,21 +195,24 @@ def downsample_avg( class RegStage(nn.Module): """Stage (sequence of blocks w/ the same output shape).""" - def __init__(self, in_chs, out_chs, stride, depth, block_fn, bottle_ratio, group_width, se_ratio): + def __init__(self, in_chs, out_chs, stride, dilation, depth, bottle_ratio, group_width, + block_fn=Bottleneck, se_ratio=0.): super(RegStage, self).__init__() block_kwargs = {} # FIXME setup to pass various aa, norm, act layer common args + first_dilation = 1 if dilation in (1, 2) else 2 for i in range(depth): block_stride = stride if i == 0 else 1 block_in_chs = in_chs if i == 0 else out_chs + block_dilation = first_dilation if i == 0 else dilation if (block_in_chs != out_chs) or (block_stride != 1): - proj_block = downsample_conv(block_in_chs, out_chs, 1, stride) + proj_block = downsample_conv(block_in_chs, out_chs, 1, block_stride, block_dilation) else: proj_block = None name = "b{}".format(i + 1) self.add_module( name, block_fn( - block_in_chs, out_chs, block_stride, bottle_ratio, group_width, se_ratio, + block_in_chs, out_chs, block_stride, block_dilation, bottle_ratio, group_width, se_ratio, downsample=proj_block, **block_kwargs) ) @@ -247,26 +249,30 @@ class RegNet(nn.Module): Original Impl: https://github.com/facebookresearch/pycls/blob/master/pycls/models/regnet.py """ - def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0., + def __init__(self, cfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg', drop_rate=0., zero_init_last_bn=True): super().__init__() # TODO add drop block, drop path, anti-aliasing, custom bn/act args self.num_classes = num_classes self.drop_rate = drop_rate + assert output_stride in (8, 16, 32) # Construct the stem stem_width = cfg['stem_width'] self.stem = ConvBnAct(in_chans, stem_width, 3, stride=2) - + self.feature_info = [dict(num_chs=stem_width, reduction=2, module='stem')] + # Construct the stages - block_fn = Bottleneck prev_width = stem_width - stage_params = self._get_stage_params(cfg) + curr_stride = 2 + stage_params = self._get_stage_params(cfg, output_stride=output_stride) se_ratio = cfg['se_ratio'] - for i, (d, w, s, br, gw) in enumerate(stage_params): - self.add_module( - "s{}".format(i + 1), RegStage(prev_width, w, s, d, block_fn, br, gw, se_ratio)) - prev_width = w + for i, stage_args in enumerate(stage_params): + stage_name = "s{}".format(i + 1) + self.add_module(stage_name, RegStage(prev_width, **stage_args, se_ratio=se_ratio)) + prev_width = stage_args['out_chs'] + curr_stride *= stage_args['stride'] + self.feature_info += [dict(num_chs=prev_width, reduction=curr_stride, module=stage_name)] # Construct the head self.num_features = prev_width @@ -287,7 +293,7 @@ class RegNet(nn.Module): if hasattr(m, 'zero_init_last_bn'): m.zero_init_last_bn() - def _get_stage_params(self, cfg, stride=2): + def _get_stage_params(self, cfg, default_stride=2, output_stride=32): # Generate RegNet ws per block w_a, w_0, w_m, d = cfg['wa'], cfg['w0'], cfg['wm'], cfg['depth'] widths, num_stages, _, _ = generate_regnet(w_a, w_0, w_m, d) @@ -298,12 +304,26 @@ class RegNet(nn.Module): # Use the same group width, bottleneck mult and stride for each stage stage_groups = [cfg['group_w'] for _ in range(num_stages)] stage_bottle_ratios = [cfg['bottle_ratio'] for _ in range(num_stages)] - stage_strides = [stride for _ in range(num_stages)] - # FIXME add 
dilation / output_stride support + stage_strides = [] + stage_dilations = [] + total_stride = 2 + dilation = 1 + for _ in range(num_stages): + if total_stride >= output_stride: + dilation *= default_stride + stride = 1 + else: + stride = default_stride + total_stride *= stride + stage_strides.append(stride) + stage_dilations.append(dilation) # Adjust the compatibility of ws and gws stage_widths, stage_groups = adjust_widths_groups_comp(stage_widths, stage_bottle_ratios, stage_groups) - stage_params = list(zip(stage_depths, stage_widths, stage_strides, stage_bottle_ratios, stage_groups)) + param_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_width'] + stage_params = [ + dict(zip(param_names, params)) for params in + zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_bottle_ratios, stage_groups)] return stage_params def get_classifier(self): @@ -324,20 +344,20 @@ class RegNet(nn.Module): def _regnet(variant, pretrained, **kwargs): - load_strict = True - model_class = RegNet + features = False + out_indices = None if kwargs.pop('features_only', False): - assert False, 'Not Implemented' # TODO - load_strict = False - kwargs.pop('num_classes', 0) + features = True + out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4)) model_cfg = model_cfgs[variant] - default_cfg = default_cfgs[variant] - model = model_class(model_cfg, **kwargs) - model.default_cfg = default_cfg + model = RegNet(model_cfg, **kwargs) + model.default_cfg = default_cfgs[variant] if pretrained: load_pretrained( - model, default_cfg, - num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=load_strict) + model, + num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=not features) + if features: + model = FeatureNet(model, out_indices=out_indices) return model diff --git a/timm/models/resnet.py b/timm/models/resnet.py index dfa864f2..1b87ed08 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -33,6 +33,7 @@ def _cfg(url='', **kwargs): default_cfgs = { + # ResNet and Wide ResNet 'resnet18': _cfg(url='https://download.pytorch.org/models/resnet18-5c106cde.pth'), 'resnet34': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth'), @@ -54,6 +55,8 @@ default_cfgs = { 'tv_resnet50': _cfg(url='https://download.pytorch.org/models/resnet50-19c8e357.pth'), 'wide_resnet50_2': _cfg(url='https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth'), 'wide_resnet101_2': _cfg(url='https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth'), + + # ResNeXt 'resnext50_32x4d': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50_32x4d_ra-d733960d.pth', interpolation='bicubic'), @@ -64,10 +67,17 @@ default_cfgs = { 'resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth'), 'resnext101_64x4d': _cfg(url=''), 'tv_resnext50_32x4d': _cfg(url='https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth'), + + # ResNeXt models - Weakly Supervised Pretraining on Instagram Hashtags + # from https://github.com/facebookresearch/WSL-Images + # Please note the CC-BY-NC 4.0 license on theses weights, non-commercial use only. 
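The stride/dilation bookkeeping in _get_stage_params is the same pattern the ResNet rework below uses: once the accumulated stride reaches output_stride, any further stride-2 stage becomes stride 1 and the dilation is multiplied instead. The loop, extracted as a standalone sketch:

```python
def stage_strides_dilations(num_stages=4, default_stride=2, output_stride=8):
    # mirrors RegNet._get_stage_params: the stem contributes a stride of 2 up front
    strides, dilations = [], []
    total_stride, dilation = 2, 1
    for _ in range(num_stages):
        if total_stride >= output_stride:
            dilation *= default_stride  # trade further striding for dilation
            stride = 1
        else:
            stride = default_stride
            total_stride *= stride
        strides.append(stride)
        dilations.append(dilation)
    return strides, dilations

# stage_strides_dilations(output_stride=8)  -> ([2, 2, 1, 1], [1, 1, 2, 4])
# stage_strides_dilations(output_stride=32) -> ([2, 2, 2, 2], [1, 1, 1, 1])
```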
diff --git a/timm/models/resnet.py b/timm/models/resnet.py
index dfa864f2..1b87ed08 100644
--- a/timm/models/resnet.py
+++ b/timm/models/resnet.py
@@ -33,6 +33,7 @@ def _cfg(url='', **kwargs):
 
 
 default_cfgs = {
+    # ResNet and Wide ResNet
     'resnet18': _cfg(url='https://download.pytorch.org/models/resnet18-5c106cde.pth'),
     'resnet34': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth'),
@@ -54,6 +55,8 @@ default_cfgs = {
     'tv_resnet50': _cfg(url='https://download.pytorch.org/models/resnet50-19c8e357.pth'),
     'wide_resnet50_2': _cfg(url='https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth'),
     'wide_resnet101_2': _cfg(url='https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth'),
+
+    # ResNeXt
     'resnext50_32x4d': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50_32x4d_ra-d733960d.pth',
         interpolation='bicubic'),
@@ -64,10 +67,17 @@ default_cfgs = {
     'resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth'),
     'resnext101_64x4d': _cfg(url=''),
     'tv_resnext50_32x4d': _cfg(url='https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth'),
+
+    # ResNeXt models - Weakly Supervised Pretraining on Instagram Hashtags
+    # from https://github.com/facebookresearch/WSL-Images
+    # Please note the CC-BY-NC 4.0 license on these weights, non-commercial use only.
     'ig_resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth'),
     'ig_resnext101_32x16d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth'),
     'ig_resnext101_32x32d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth'),
     'ig_resnext101_32x48d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth'),
+
+    # Semi-Supervised ResNe*t models from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models
+    # Please note the CC-BY-NC 4.0 license on these weights, non-commercial use only.
     'ssl_resnet18': _cfg(
         url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet18-d92f0530.pth'),
     'ssl_resnet50': _cfg(
@@ -80,6 +90,9 @@ default_cfgs = {
         url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x8-2cfe2f8b.pth'),
     'ssl_resnext101_32x16d': _cfg(
         url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x16-15fffa57.pth'),
+
+    # Semi-Weakly Supervised ResNe*t models from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models
+    # Please note the CC-BY-NC 4.0 license on these weights, non-commercial use only.
     'swsl_resnet18': _cfg(
         url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet18-118f1556.pth'),
     'swsl_resnet50': _cfg(
@@ -92,6 +105,31 @@ default_cfgs = {
         url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x8-b4712904.pth'),
     'swsl_resnext101_32x16d': _cfg(
         url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x16-f3559a9c.pth'),
+
+    # Squeeze-Excitation ResNets, to eventually replace the models in senet.py
+    'seresnet18': _cfg(
+        url='',
+        interpolation='bicubic'),
+    'seresnet34': _cfg(
+        url='',
+        interpolation='bicubic'),
+    'seresnet50': _cfg(
+        url='',
+        interpolation='bicubic'),
+    'seresnet50tn': _cfg(
+        url='',
+        interpolation='bicubic'),
+    'seresnet101': _cfg(
+        url='',
+        interpolation='bicubic'),
+    'seresnet152': _cfg(
+        url='',
+        interpolation='bicubic'),
+
+    # Squeeze-Excitation ResNeXts, to eventually replace the models in senet.py
+    'seresnext26_32x4d': _cfg(
+        url='',
+        interpolation='bicubic'),
     'seresnext26d_32x4d': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26d_32x4d-80fa48a3.pth',
         interpolation='bicubic'),
@@ -101,9 +139,19 @@ default_cfgs = {
     'seresnext26tn_32x4d': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26tn_32x4d-569cb627.pth',
         interpolation='bicubic'),
-    'ecaresnext26tn_32x4d': _cfg(
+    'seresnext50_32x4d': _cfg(
+        interpolation='bicubic'),
+    'seresnext101_32x4d': _cfg(
         url='',
         interpolation='bicubic'),
+    'seresnext101_32x8d': _cfg(
+        url='',
+        interpolation='bicubic'),
+    'senet154': _cfg(
+        url='',
+        interpolation='bicubic'),
+
+    # Efficient Channel Attention ResNets
     'ecaresnet18': _cfg(),
     'ecaresnet50': _cfg(),
     'ecaresnetlight': _cfg(
@@ -121,6 +169,16 @@ default_cfgs = {
     'ecaresnet101d_pruned': _cfg(
         url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth',
         interpolation='bicubic'),
+
+    # Efficient Channel Attention ResNeXts
+    'ecaresnext26tn_32x4d': _cfg(
+        url='',
+        interpolation='bicubic'),
+    'ecaresnext50_32x4d': _cfg(
+        url='',
+        interpolation='bicubic'),
+
+    # ResNets with anti-aliasing blur pool
     'resnetblur18': _cfg(
         interpolation='bicubic'),
     'resnetblur50': _cfg(
@@ -278,6 +336,14 @@ class Bottleneck(nn.Module):
         return x
 
 
+def setup_drop_block(drop_block_rate=0.):
+    return [
+        None,
+        None,
+        DropBlock2d(drop_block_rate, 5, 0.25) if drop_block_rate else None,
+        DropBlock2d(drop_block_rate, 3, 1.00) if drop_block_rate else None]
+
+
 def downsample_conv(
         in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None):
     norm_layer = norm_layer or nn.BatchNorm2d
@@ -386,6 +452,7 @@ class ResNet(nn.Module):
                  act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0.0, drop_path_rate=0.,
                  drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None):
         block_args = block_args or dict()
+        assert output_stride in (8, 16, 32)
         self.num_classes = num_classes
         deep_stem = 'deep' in stem_type
         self.inplanes = stem_width * 2 if deep_stem else 64
@@ -393,7 +460,6 @@ class ResNet(nn.Module):
         self.base_width = base_width
         self.drop_rate = drop_rate
         self.expansion = block.expansion
-        self.feature_info = [dict(num_chs=self.inplanes, reduction=2, module='act1')]
         super(ResNet, self).__init__()
 
         # Stem
@@ -414,6 +480,8 @@
             self.conv1 = nn.Conv2d(in_chans, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
         self.bn1 = norm_layer(self.inplanes)
         self.act1 = act_layer(inplace=True)
+        self.feature_info = [dict(num_chs=self.inplanes, reduction=2, module='act1')]
+
         # Stem Pooling
         if aa_layer is not None:
             self.maxpool = nn.Sequential(*[
@@ -424,32 +492,26 @@
             self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
 
         # Feature Blocks
-        channels, strides, dilations = [64, 128, 256, 512], [1, 2, 2, 2], [1] * 4
-        if output_stride == 16:
-            strides[3] = 1
-            dilations[3] = 2
-        elif output_stride == 8:
-            strides[2:4] = [1, 1]
-            dilations[2:4] = [2, 4]
-        else:
-            assert output_stride == 32
+        channels = [64, 128, 256, 512]
         dp = DropPath(drop_path_rate) if drop_path_rate else None
-        db = [
-            None, None,
-            DropBlock2d(drop_block_rate, 5, 0.25) if drop_block_rate else None,
-            DropBlock2d(drop_block_rate, 3, 1.00) if drop_block_rate else None]
-        layer_args = list(zip(channels, layers, strides, dilations))
+        db = setup_drop_block(drop_block_rate)
         layer_kwargs = dict(
             reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer,
             aa_layer=aa_layer, avg_down=avg_down, down_kernel_size=down_kernel_size, drop_path=dp, **block_args)
-        current_stride = 4
+        total_stride = 4
+        dilation = 1
         for i in range(4):
             layer_name = f'layer{i + 1}'
+            stride = 2 if i > 0 else 1
+            if total_stride >= output_stride:
+                dilation *= stride
+                stride = 1
+            else:
+                total_stride *= stride
             self.add_module(layer_name, self._make_layer(
-                block, *layer_args[i], drop_block=db[i], **layer_kwargs))
-            current_stride *= strides[i]
+                block, channels[i], layers[i], stride, dilation, drop_block=db[i], **layer_kwargs))
             self.feature_info.append(dict(
-                num_chs=self.inplanes, reduction=current_stride, module=layer_name))
+                num_chs=self.inplanes, reduction=total_stride, module=layer_name))
 
         # Head (Pooling and Classifier)
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
@@ -872,55 +934,6 @@ def swsl_resnext101_32x16d(pretrained=True, **kwargs):
     return _create_resnet('swsl_resnext101_32x16d', pretrained, **model_args)
 
 
-@register_model
-def seresnext26d_32x4d(pretrained=False, **kwargs):
-    """Constructs a SE-ResNeXt-26-D model.
-    This is technically a 28 layer ResNet, using the 'D' modifier from Gluon / bag-of-tricks for
-    combination of deep stem and avg_pool in downsample.
-    """
-    model_args = dict(
-        block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
-        stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
-    return _create_resnet('seresnext26d_32x4d', pretrained, **model_args)
-
-
-@register_model
-def seresnext26t_32x4d(pretrained=False, **kwargs):
-    """Constructs a SE-ResNet-26-T model.
-    This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 48, 64 channels
-    in the deep stem.
-    """
-    model_args = dict(
-        block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
-        stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
-    return _create_resnet('seresnext26t_32x4d', pretrained, **model_args)
-
-
-@register_model
-def seresnext26tn_32x4d(pretrained=False, **kwargs):
-    """Constructs a SE-ResNeXt-26-TN model.
-    This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
-    in the deep stem. The channel number of the middle stem conv is narrower than the 'T' variant.
-    """
-    model_args = dict(
-        block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
-        stem_type='deep_tiered_narrow', avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
-    return _create_resnet('seresnext26tn_32x4d', pretrained, **model_args)
-
-
-@register_model
-def ecaresnext26tn_32x4d(pretrained=False, **kwargs):
-    """Constructs an ECA-ResNeXt-26-TN model.
-    This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
-    in the deep stem. The channel number of the middle stem conv is narrower than the 'T' variant.
-    this model replaces SE module with the ECA module
-    """
-    model_args = dict(
-        block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
-        stem_type='deep_tiered_narrow', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs)
-    return _create_resnet('ecaresnext26tn_32x4d', pretrained, **model_args)
-
-
 @register_model
 def ecaresnet18(pretrained=False, **kwargs):
     """ Constructs an ECA-ResNet-18 model.
@@ -989,6 +1002,19 @@ def ecaresnet101d_pruned(pretrained=False, **kwargs):
     return _create_resnet('ecaresnet101d_pruned', pretrained, pruned=True, **model_args)
 
 
+@register_model
+def ecaresnext26tn_32x4d(pretrained=False, **kwargs):
+    """Constructs an ECA-ResNeXt-26-TN model.
+    This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
+    in the deep stem. The channel number of the middle stem conv is narrower than the 'T' variant.
+    This model replaces the SE module with an ECA module.
+    """
+    model_args = dict(
+        block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
+        stem_type='deep_tiered_narrow', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs)
+    return _create_resnet('ecaresnext26tn_32x4d', pretrained, **model_args)
+
+
 @register_model
 def resnetblur18(pretrained=False, **kwargs):
     """Constructs a ResNet-18 model with blur anti-aliasing
@@ -1003,3 +1029,123 @@ def resnetblur50(pretrained=False, **kwargs):
     """
     model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d, **kwargs)
     return _create_resnet('resnetblur50', pretrained, **model_args)
+
+
+@register_model
+def seresnet18(pretrained=False, **kwargs):
+    model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='se'), **kwargs)
+    return _create_resnet('seresnet18', pretrained, **model_args)
+
+
+@register_model
+def seresnet34(pretrained=False, **kwargs):
+    model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se'), **kwargs)
+    return _create_resnet('seresnet34', pretrained, **model_args)
+
+
+@register_model
+def seresnet50(pretrained=False, **kwargs):
+    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se'), **kwargs)
+    return _create_resnet('seresnet50', pretrained, **model_args)
+
+
+@register_model
+def seresnet50tn(pretrained=False, **kwargs):
+    model_args = dict(
+        block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered_narrow', avg_down=True,
+        block_args=dict(attn_layer='se'), **kwargs)
+    return _create_resnet('seresnet50tn', pretrained, **model_args)
+
+
+@register_model
+def seresnet101(pretrained=False, **kwargs):
+    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], block_args=dict(attn_layer='se'), **kwargs)
+    return _create_resnet('seresnet101', pretrained, **model_args)
+
+
+@register_model
+def seresnet152(pretrained=False, **kwargs):
+    model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], block_args=dict(attn_layer='se'), **kwargs)
+    return _create_resnet('seresnet152', pretrained, **model_args)
+
+
+@register_model
+def seresnext26_32x4d(pretrained=False, **kwargs):
+    model_args = dict(
+        block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4,
+        block_args=dict(attn_layer='se'), **kwargs)
+    return _create_resnet('seresnext26_32x4d', pretrained, **model_args)
+
+
+@register_model
+def seresnext26d_32x4d(pretrained=False, **kwargs):
+    """Constructs a SE-ResNeXt-26-D model.
+    This is technically a 28 layer ResNet, using the 'D' modifier from Gluon / bag-of-tricks for
+    combination of deep stem and avg_pool in downsample.
+    """
+    model_args = dict(
+        block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
+        stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
+    return _create_resnet('seresnext26d_32x4d', pretrained, **model_args)
+
+
+@register_model
+def seresnext26t_32x4d(pretrained=False, **kwargs):
+    """Constructs a SE-ResNet-26-T model.
+    This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 48, 64 channels
+    in the deep stem.
+    """
+    model_args = dict(
+        block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
+        stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
+    return _create_resnet('seresnext26t_32x4d', pretrained, **model_args)
+
+
+@register_model
+def seresnext26tn_32x4d(pretrained=False, **kwargs):
+    """Constructs a SE-ResNeXt-26-TN model.
+    This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
+    in the deep stem. The channel number of the middle stem conv is narrower than the 'T' variant.
+    """
+    model_args = dict(
+        block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
+        stem_type='deep_tiered_narrow', avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
+    return _create_resnet('seresnext26tn_32x4d', pretrained, **model_args)
+
+
+@register_model
+def seresnext50_32x4d(pretrained=False, **kwargs):
+    model_args = dict(
+        block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4,
+        block_args=dict(attn_layer='se'), **kwargs)
+    return _create_resnet('seresnext50_32x4d', pretrained, **model_args)
+
+
+@register_model
+def seresnext101_32x4d(pretrained=False, **kwargs):
+    model_args = dict(
+        block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4,
+        block_args=dict(attn_layer='se'), **kwargs)
+    return _create_resnet('seresnext101_32x4d', pretrained, **model_args)
+
+
+@register_model
+def seresnext101_32x8d(pretrained=False, **kwargs):
+    model_args = dict(
+        block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8,
+        block_args=dict(attn_layer='se'), **kwargs)
+    return _create_resnet('seresnext101_32x8d', pretrained, **model_args)
+
+
+@register_model
+def senet154(pretrained=False, **kwargs):
+    model_args = dict(
+        block=Bottleneck, layers=[3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep',
+        down_kernel_size=3, block_reduce_first=2, block_args=dict(attn_layer='se'), **kwargs)
+    return _create_resnet('senet154', pretrained, **model_args)
+
+
+@register_model
+def eseresnet50(pretrained=False, **kwargs):
+    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], block_args=dict(attn_layer='ese'), **kwargs)
+    return _create_resnet('seresnet50', pretrained, **model_args)
""" import math from collections import OrderedDict @@ -397,7 +400,7 @@ class SENet(nn.Module): @register_model -def seresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def legacy_seresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): default_cfg = default_cfgs['seresnet18'] model = SENet(SEResNetBlock, [2, 2, 2, 2], groups=1, reduction=16, inplanes=64, input_3x3=False, @@ -410,7 +413,7 @@ def seresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): @register_model -def seresnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def legacy_seresnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs): default_cfg = default_cfgs['seresnet34'] model = SENet(SEResNetBlock, [3, 4, 6, 3], groups=1, reduction=16, inplanes=64, input_3x3=False, @@ -423,7 +426,7 @@ def seresnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs): @register_model -def seresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def legacy_seresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): default_cfg = default_cfgs['seresnet50'] model = SENet(SEResNetBottleneck, [3, 4, 6, 3], groups=1, reduction=16, inplanes=64, input_3x3=False, @@ -436,7 +439,7 @@ def seresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): @register_model -def seresnet101(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def legacy_seresnet101(pretrained=False, num_classes=1000, in_chans=3, **kwargs): default_cfg = default_cfgs['seresnet101'] model = SENet(SEResNetBottleneck, [3, 4, 23, 3], groups=1, reduction=16, inplanes=64, input_3x3=False, @@ -449,7 +452,7 @@ def seresnet101(pretrained=False, num_classes=1000, in_chans=3, **kwargs): @register_model -def seresnet152(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def legacy_seresnet152(pretrained=False, num_classes=1000, in_chans=3, **kwargs): default_cfg = default_cfgs['seresnet152'] model = SENet(SEResNetBottleneck, [3, 8, 36, 3], groups=1, reduction=16, inplanes=64, input_3x3=False, @@ -462,7 +465,7 @@ def seresnet152(pretrained=False, num_classes=1000, in_chans=3, **kwargs): @register_model -def senet154(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def legacy_senet154(pretrained=False, num_classes=1000, in_chans=3, **kwargs): default_cfg = default_cfgs['senet154'] model = SENet(SEBottleneck, [3, 8, 36, 3], groups=64, reduction=16, num_classes=num_classes, in_chans=in_chans, **kwargs) @@ -473,7 +476,7 @@ def senet154(pretrained=False, num_classes=1000, in_chans=3, **kwargs): @register_model -def seresnext26_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def legacy_seresnext26_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): default_cfg = default_cfgs['seresnext26_32x4d'] model = SENet(SEResNeXtBottleneck, [2, 2, 2, 2], groups=32, reduction=16, inplanes=64, input_3x3=False, @@ -486,7 +489,7 @@ def seresnext26_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): @register_model -def seresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def legacy_seresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): default_cfg = default_cfgs['seresnext50_32x4d'] model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16, inplanes=64, input_3x3=False, @@ -499,7 +502,7 @@ def seresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): @register_model -def seresnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def 
legacy_seresnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): default_cfg = default_cfgs['seresnext101_32x4d'] model = SENet(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16, inplanes=64, input_3x3=False, diff --git a/timm/models/vovnet.py b/timm/models/vovnet.py index 0793120e..b1520149 100644 --- a/timm/models/vovnet.py +++ b/timm/models/vovnet.py @@ -275,13 +275,14 @@ class ClassifierHead(nn.Module): class VovNet(nn.Module): def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0., stem_stride=4, - norm_layer=BatchNormAct2d): + output_stride=32, norm_layer=BatchNormAct2d): """ VovNet (v2) """ super(VovNet, self).__init__() self.num_classes = num_classes self.drop_rate = drop_rate assert stem_stride in (4, 2) + assert output_stride == 32 # FIXME support dilation stem_chs = cfg["stem_chs"] stage_conv_chs = cfg["stage_conv_chs"] @@ -349,7 +350,6 @@ def _vovnet(variant, pretrained=False, **kwargs): out_indices = None if kwargs.pop('features_only', False): features = True - kwargs.pop('num_classes', 0) out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4)) model_cfg = model_cfgs[variant] model = VovNet(model_cfg, **kwargs) @@ -412,10 +412,11 @@ def eca_vovnet39b(pretrained=False, **kwargs): @register_model def ese_vovnet39b_evos(pretrained=False, **kwargs): - def norm_act_fn(num_features, **kwargs): - return create_norm_act('EvoNormSample', num_features, jit=False, **kwargs) + def norm_act_fn(num_features, **nkwargs): + return create_norm_act('EvoNormSample', num_features, jit=False, **nkwargs) return _vovnet('ese_vovnet39b_evos', pretrained=pretrained, norm_layer=norm_act_fn, **kwargs) + @register_model def ese_vovnet99b_iabn(pretrained=False, **kwargs): norm_layer = get_norm_act_layer('iabn') diff --git a/timm/models/xception_aligned.py b/timm/models/xception_aligned.py new file mode 100644 index 00000000..c2006173 --- /dev/null +++ b/timm/models/xception_aligned.py @@ -0,0 +1,278 @@ +"""Pytorch impl of Aligned Xception + +This is a correct impl of Aligned Xception (Deeplab) models compatible with TF definition. 
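With the renames in senet.py above, the original SENet implementations remain reachable under a legacy_ prefix while the unprefixed names now resolve to the new generic ResNet-based variants added in resnet.py. A sketch:

```python
import timm

legacy = timm.create_model('legacy_seresnet50', pretrained=False)  # senet.py impl
modern = timm.create_model('seresnet50', pretrained=False)         # resnet.py impl, supports output_stride etc.
```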
diff --git a/timm/models/xception_aligned.py b/timm/models/xception_aligned.py
new file mode 100644
index 00000000..c2006173
--- /dev/null
+++ b/timm/models/xception_aligned.py
@@ -0,0 +1,278 @@
+"""PyTorch impl of Aligned Xception
+
+This is a correct impl of Aligned Xception (Deeplab) models compatible with the TF definition.
+
+Hacked together by Ross Wightman
+"""
+from collections import OrderedDict
+
+import torch.nn as nn
+import torch.nn.functional as F
+
+from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+from .features import FeatureNet
+from .helpers import load_pretrained
+from .layers import SelectAdaptivePool2d, ConvBnAct, create_conv2d
+from .registry import register_model
+
+__all__ = ['XceptionAligned']
+
+
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (10, 10),
+        'crop_pct': 0.903, 'interpolation': 'bicubic',
+        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
+        'first_conv': 'stem.0', 'classifier': 'head.fc',
+        **kwargs
+    }
+
+
+default_cfgs = dict(
+    xception41=_cfg(url=''),
+    xception65=_cfg(url=''),
+    xception71=_cfg(url=''),
+)
+
+
+class SeparableConv2d(nn.Module):
+    def __init__(
+            self, inplanes, planes, kernel_size=3, stride=1, dilation=1, padding='',
+            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
+        super(SeparableConv2d, self).__init__()
+        norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
+        self.kernel_size = kernel_size
+        self.dilation = dilation
+
+        # depthwise convolution
+        self.conv_dw = create_conv2d(
+            inplanes, inplanes, kernel_size, stride=stride,
+            padding=padding, dilation=dilation, depthwise=True)
+        self.bn_dw = norm_layer(inplanes, **norm_kwargs)
+        if act_layer is not None:
+            self.act_dw = act_layer(inplace=True)
+        else:
+            self.act_dw = None
+
+        # pointwise convolution
+        self.conv_pw = create_conv2d(inplanes, planes, kernel_size=1)
+        self.bn_pw = norm_layer(planes, **norm_kwargs)
+        if act_layer is not None:
+            self.act_pw = act_layer(inplace=True)
+        else:
+            self.act_pw = None
+
+    def forward(self, x):
+        x = self.conv_dw(x)
+        x = self.bn_dw(x)
+        if self.act_dw is not None:
+            x = self.act_dw(x)
+        x = self.conv_pw(x)
+        x = self.bn_pw(x)
+        if self.act_pw is not None:
+            x = self.act_pw(x)
+        return x
+
+
+class XceptionModule(nn.Module):
+    def __init__(
+            self, in_chs, out_chs, stride=1, dilation=1, pad_type='',
+            start_with_relu=True, no_skip=False, act_layer=nn.ReLU, norm_layer=None, norm_kwargs=None):
+        super(XceptionModule, self).__init__()
+        norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
+        if isinstance(out_chs, (list, tuple)):
+            assert len(out_chs) == 3
+        else:
+            out_chs = (out_chs,) * 3
+        self.in_channels = in_chs
+        self.out_channels = out_chs[-1]
+        self.no_skip = no_skip
+        if not no_skip and (self.out_channels != self.in_channels or stride != 1):
+            self.shortcut = ConvBnAct(
+                in_chs, self.out_channels, 1, stride=stride,
+                norm_layer=norm_layer, norm_kwargs=norm_kwargs, act_layer=None)
+        else:
+            self.shortcut = None
+
+        separable_act_layer = None if start_with_relu else act_layer
+        self.stack = nn.Sequential()
+        for i in range(3):
+            if start_with_relu:
+                self.stack.add_module(f'act{i + 1}', nn.ReLU(inplace=i > 0))
+            self.stack.add_module(f'conv{i + 1}', SeparableConv2d(
+                in_chs, out_chs[i], 3, stride=stride if i == 2 else 1, dilation=dilation, padding=pad_type,
+                act_layer=separable_act_layer, norm_layer=norm_layer, norm_kwargs=norm_kwargs))
+            in_chs = out_chs[i]
+
+    def forward(self, x):
+        skip = x
+        x = self.stack(x)
+        if self.shortcut is not None:
+            skip = self.shortcut(skip)
+        if not self.no_skip:
+            x = x + skip
+        return x
+
+
+class ClassifierHead(nn.Module):
+    """Head."""
+
+    def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0.):
+        super(ClassifierHead, self).__init__()
+        self.drop_rate = drop_rate
+        self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
+        if num_classes > 0:
+            self.fc = nn.Linear(in_chs, num_classes, bias=True)
+        else:
+            self.fc = nn.Identity()
+
+    def forward(self, x):
+        x = self.global_pool(x).flatten(1)
+        if self.drop_rate:
+            x = F.dropout(x, p=float(self.drop_rate), training=self.training)
+        x = self.fc(x)
+        return x
+
+
+class XceptionAligned(nn.Module):
+    """Modified Aligned Xception
+    """
+
+    def __init__(self, block_cfg, num_classes=1000, in_chans=3, output_stride=32,
+                 act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_rate=0., global_pool='avg'):
+        super(XceptionAligned, self).__init__()
+        self.num_classes = num_classes
+        self.drop_rate = drop_rate
+        assert output_stride in (8, 16, 32)
+        norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
+
+        xtra_args = dict(act_layer=act_layer, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
+        self.stem = nn.Sequential(*[
+            ConvBnAct(in_chans, 32, kernel_size=3, stride=2, **xtra_args),
+            ConvBnAct(32, 64, kernel_size=3, stride=1, **xtra_args)
+        ])
+        curr_dilation = 1
+        curr_stride = 2
+        self.feature_info = [dict(num_chs=64, reduction=curr_stride, module='stem.1')]
+
+        self.blocks = nn.Sequential()
+        for i, b in enumerate(block_cfg):
+            feature_extract = False
+            b['dilation'] = curr_dilation
+            if b['stride'] > 1:
+                feature_extract = True
+                next_stride = curr_stride * b['stride']
+                if next_stride > output_stride:
+                    curr_dilation *= b['stride']
+                    b['stride'] = 1
+                else:
+                    curr_stride = next_stride
+            self.blocks.add_module(str(i), XceptionModule(**b, **xtra_args))
+            self.num_features = self.blocks[-1].out_channels
+            if feature_extract:
+                self.feature_info += [dict(
+                    num_chs=self.num_features, reduction=curr_stride, module=f'blocks.{i}.stack.act2')]
+
+        self.feature_info += [dict(
+            num_chs=self.num_features, reduction=curr_stride, module='blocks.' + str(len(self.blocks) - 1))]
+
+        self.head = ClassifierHead(
+            in_chs=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
+
+    def get_classifier(self):
+        return self.head.fc
+
+    def reset_classifier(self, num_classes, global_pool='avg'):
+        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
+
+    def forward_features(self, x):
+        x = self.stem(x)
+        x = self.blocks(x)
+        return x
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.head(x)
+        return x
+
+
+def _xception(variant, pretrained=False, **kwargs):
+    features = False
+    out_indices = None
+    if kwargs.pop('features_only', False):
+        features = True
+        kwargs.pop('num_classes', 0)
+        out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
+    model = XceptionAligned(**kwargs)
+    model.default_cfg = default_cfgs[variant]
+    if pretrained:
+        load_pretrained(
+            model,
+            num_classes=kwargs.get('num_classes', 0),
+            in_chans=kwargs.get('in_chans', 3),
+            strict=not features)
+    if features:
+        model = FeatureNet(model, out_indices)
+    return model
+
+
+@register_model
+def xception41(pretrained=False, **kwargs):
+    """ Modified Aligned Xception-41
+    """
+    block_cfg = [
+        # entry flow
+        dict(in_chs=64, out_chs=128, stride=2),
+        dict(in_chs=128, out_chs=256, stride=2),
+        dict(in_chs=256, out_chs=728, stride=2),
+        # middle flow
+        *([dict(in_chs=728, out_chs=728, stride=1)] * 8),
+        # exit flow
+        dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2),
+        dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False),
+    ]
+    model_args = dict(block_cfg=block_cfg, norm_kwargs=dict(eps=.001, momentum=.1), **kwargs)
+    return _xception('xception41', pretrained=pretrained, **model_args)
+
+
+@register_model
+def xception65(pretrained=False, **kwargs):
+    """ Modified Aligned Xception-65
+    """
+    block_cfg = [
+        # entry flow
+        dict(in_chs=64, out_chs=128, stride=2),
+        dict(in_chs=128, out_chs=256, stride=2),
+        dict(in_chs=256, out_chs=728, stride=2),
+        # middle flow
+        *([dict(in_chs=728, out_chs=728, stride=1)] * 16),
+        # exit flow
+        dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2),
+        dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False),
+    ]
+    model_args = dict(block_cfg=block_cfg, norm_kwargs=dict(eps=.001, momentum=.1), **kwargs)
+    return _xception('xception65', pretrained=pretrained, **model_args)
+
+
+@register_model
+def xception71(pretrained=False, **kwargs):
+    """ Modified Aligned Xception-71
+    """
+    block_cfg = [
+        # entry flow
+        dict(in_chs=64, out_chs=128, stride=2),
+        dict(in_chs=128, out_chs=256, stride=1),
+        dict(in_chs=256, out_chs=256, stride=2),
+        dict(in_chs=256, out_chs=728, stride=1),
+        dict(in_chs=728, out_chs=728, stride=2),
+        # middle flow
+        *([dict(in_chs=728, out_chs=728, stride=1)] * 16),
+        # exit flow
+        dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2),
+        dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False),
+    ]
+    model_args = dict(block_cfg=block_cfg, norm_kwargs=dict(eps=.001, momentum=.1), **kwargs)
+    return _xception('xception71', pretrained=pretrained, **model_args)
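A usage sketch for the new aligned Xception models (the pretrained URLs are empty placeholders for now, hence pretrained=False):

```python
import torch
import timm

m = timm.create_model('xception65', pretrained=False, output_stride=16)
out = m.forward_features(torch.randn(1, 3, 299, 299))
# at output_stride=16 the exit-flow stride-2 module runs at stride 1 with
# dilation 2, so the final map should be ~19x19 rather than the 10x10 of os=32
print(out.shape)
```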