From 50658b9a673bb3d844d06dfa039de2a9e4ddbae8 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 18 May 2020 00:08:52 -0700
Subject: [PATCH 1/7] Add RegNet models and weights

---
 timm/models/__init__.py  |   1 +
 timm/models/layers/se.py |   4 +-
 timm/models/regnet.py    | 485 +++++++++++++++++++++++++++++++++++++++
 3 files changed, 488 insertions(+), 2 deletions(-)
 create mode 100644 timm/models/regnet.py

diff --git a/timm/models/__init__.py b/timm/models/__init__.py
index d421ad45..06d26fb3 100644
--- a/timm/models/__init__.py
+++ b/timm/models/__init__.py
@@ -19,6 +19,7 @@ from .hrnet import *
 from .sknet import *
 from .tresnet import *
 from .resnest import *
+from .regnet import *
 
 from .registry import *
 from .factory import create_model

diff --git a/timm/models/layers/se.py b/timm/models/layers/se.py
index de87ccf5..6bb4723e 100644
--- a/timm/models/layers/se.py
+++ b/timm/models/layers/se.py
@@ -3,10 +3,10 @@ from torch import nn as nn
 
 class SEModule(nn.Module):
 
-    def __init__(self, channels, reduction=16, act_layer=nn.ReLU):
+    def __init__(self, channels, reduction=16, act_layer=nn.ReLU, min_channels=8, reduction_channels=None):
         super(SEModule, self).__init__()
         self.avg_pool = nn.AdaptiveAvgPool2d(1)
-        reduction_channels = max(channels // reduction, 8)
+        reduction_channels = reduction_channels or max(channels // reduction, min_channels)
         self.fc1 = nn.Conv2d(
             channels, reduction_channels, kernel_size=1, padding=0, bias=True)
         self.act = act_layer(inplace=True)

diff --git a/timm/models/regnet.py b/timm/models/regnet.py
new file mode 100644
index 00000000..65ba2cc6
--- /dev/null
+++ b/timm/models/regnet.py
@@ -0,0 +1,485 @@
+"""RegNet
+
+Paper: `Designing Network Design Spaces` - https://arxiv.org/abs/2003.13678
+Original Impl: https://github.com/facebookresearch/pycls/blob/master/pycls/models/regnet.py
+
+Based on original PyTorch impl linked above, but rewritten to use my own blocks (adapted from ResNet here)
+and cleaned up with more descriptive variable names.
+
+Weights from original impl have been modified
+* first layer from BGR -> RGB as most PyTorch models are
+* removed training specific dict entries from checkpoints and keep model state_dict only
+* remap names to match the ones here
+
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+from .registry import register_model
+from .helpers import load_pretrained
+from .layers import SelectAdaptivePool2d, AvgPool2dSame, ConvBnAct, SEModule
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+
+
+def _mcfg(**kwargs):
+    cfg = dict(se_ratio=0., bottle_ratio=1., stem_width=32)
+    cfg.update(**kwargs)
+    return cfg
+
+
+# Model FLOPS = three trailing digits * 10^8
+model_cfgs = dict(
+    x_002=_mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13),
+    x_004=_mcfg(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22),
+    x_006=_mcfg(w0=48, wa=36.97, wm=2.24, group_w=24, depth=16),
+    x_008=_mcfg(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16),
+    x_016=_mcfg(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18),
+    x_032=_mcfg(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25),
+    x_040=_mcfg(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23),
+    x_064=_mcfg(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17),
+    x_080=_mcfg(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23),
+    x_120=_mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19),
+    x_160=_mcfg(w0=216, wa=55.59, wm=2.1, group_w=128, depth=22),
+    x_320=_mcfg(w0=320, wa=69.86, wm=2.0, group_w=168, depth=23),
+    y_002=_mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13, se_ratio=0.25),
+    y_004=_mcfg(w0=48, wa=27.89, wm=2.09, group_w=8, depth=16, se_ratio=0.25),
+    y_006=_mcfg(w0=48, wa=32.54, wm=2.32, group_w=16, depth=15, se_ratio=0.25),
+    y_008=_mcfg(w0=56, wa=38.84, wm=2.4, group_w=16, depth=14, se_ratio=0.25),
+    y_016=_mcfg(w0=48, wa=20.71, wm=2.65, group_w=24, depth=27, se_ratio=0.25),
+    y_032=_mcfg(w0=80, wa=42.63, wm=2.66, group_w=24, depth=21, se_ratio=0.25),
+    y_040=_mcfg(w0=96, wa=31.41, wm=2.24, group_w=64, depth=22, se_ratio=0.25),
+    y_064=_mcfg(w0=112, wa=33.22, wm=2.27, group_w=72, depth=25, se_ratio=0.25),
+    y_080=_mcfg(w0=192, wa=76.82, wm=2.19, group_w=56, depth=17, se_ratio=0.25),
+    y_120=_mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, se_ratio=0.25),
+    y_160=_mcfg(w0=200, wa=106.23, wm=2.48, group_w=112, depth=18, se_ratio=0.25),
+    y_320=_mcfg(w0=232, wa=115.89, wm=2.53, group_w=232, depth=20, se_ratio=0.25),
+)
+
+
+def _cfg(url=''):
+    return {
+        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+        'crop_pct': 0.875, 'interpolation': 'bicubic',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'stem.conv', 'classifier': 'head.fc',
+    }
+
+
+default_cfgs = dict(
+    x_002=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_002-e7e85e5c.pth'),
+    x_004=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_004-7d0e9424.pth'),
+    x_006=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_006-85ec1baa.pth'),
+    x_008=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_008-d8b470eb.pth'),
+    x_016=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_016-65ca972a.pth'),
+    x_032=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_032-ed0c7f7e.pth'),
+    x_040=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_040-73c2a654.pth'),
+    x_064=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_064-29278baa.pth'),
+    x_080=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_080-7c7fcab1.pth'),
+    x_120=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_120-65d5521e.pth'),
+    x_160=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_160-c98c4112.pth'),
+    x_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_320-8ea38b93.pth'),
+    y_002=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_002-e68ca334.pth'),
+    y_004=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_004-0db870e6.pth'),
+    y_006=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_006-c67e57ec.pth'),
+    y_008=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_008-dc900dbe.pth'),
+    y_016=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_016-54367f74.pth'),
+    y_032=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_032-62b47782.pth'),
+    y_040=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_040-f0d569f9.pth'),
+    y_064=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_064-0a48325c.pth'),
+    y_080=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_080-e7f3eb93.pth'),
+    y_120=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_120-721ba79a.pth'),
+    y_160=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_160-d64013cd.pth'),
+    y_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'),
+)
+
+
+def quantize_float(f, q):
+    """Converts a float to closest non-zero int divisible by q."""
+    return int(round(f / q) * q)
+
+
+def adjust_widths_groups_comp(widths, bottle_ratios, groups):
+    """Adjusts the compatibility of widths and groups."""
+    bottleneck_widths = [int(w * b) for w, b in zip(widths, bottle_ratios)]
+    groups = [min(g, w_bot) for g, w_bot in zip(groups, bottleneck_widths)]
+    bottleneck_widths = [quantize_float(w_bot, g) for w_bot, g in zip(bottleneck_widths, groups)]
+    widths = [int(w_bot / b) for w_bot, b in zip(bottleneck_widths, bottle_ratios)]
+    return widths, groups
+
+
+def generate_regnet(width_slope, width_initial, width_mult, depth, q=8):
+    """Generates per block widths from RegNet parameters."""
+    assert width_slope >= 0 and width_initial > 0 and width_mult > 1 and width_initial % q == 0
+    widths_cont = np.arange(depth) * width_slope + width_initial
+    width_exps = np.round(np.log(widths_cont / width_initial) / np.log(width_mult))
+    widths = width_initial * np.power(width_mult, width_exps)
+    widths = np.round(np.divide(widths, q)) * q
+    num_stages, max_stage = len(np.unique(widths)), width_exps.max() + 1
+    widths, widths_cont = widths.astype(int).tolist(), widths_cont.tolist()
+    return widths, num_stages, max_stage, widths_cont
+
+
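+# Editor's note (illustrative sketch, not part of the original impl): for the x_002
+# config above, the width generation resolves to the RegNetX-200MF layout:
+#
+#   widths, num_stages, _, _ = generate_regnet(
+#       width_slope=36.44, width_initial=24, width_mult=2.49, depth=13)
+#   # widths -> [24, 56, 152, 152, 152, 152, 368, 368, 368, 368, 368, 368, 368]
+#   # i.e. num_stages = 4, stage widths (24, 56, 152, 368), stage depths (1, 1, 4, 7)
+#
+# adjust_widths_groups_comp() then caps each group width at its bottleneck width and
+# quantizes each width to a multiple of its group width (already satisfied here with
+# group_w=8).
+
+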
+class Bottleneck(nn.Module):
+    """ RegNet Bottleneck
+
+    This is almost exactly the same as a ResNet Bottleneck. The main difference is that the SE block is moved from
+    after conv3 to after conv2. Otherwise, it's just redefining the arguments for groups/bottleneck channels.
+    """
+
+    def __init__(self, in_chs, out_chs, stride=1, bottleneck_ratio=1, group_width=1, se_ratio=0.25,
+                 dilation=1, first_dilation=None, downsample=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
+                 aa_layer=None, drop_block=None, drop_path=None):
+        super(Bottleneck, self).__init__()
+        bottleneck_chs = int(round(out_chs * bottleneck_ratio))
+        groups = bottleneck_chs // group_width
+        first_dilation = first_dilation or dilation
+
+        cargs = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer, drop_block=drop_block)
+        self.conv1 = ConvBnAct(in_chs, bottleneck_chs, kernel_size=1, **cargs)
+        self.conv2 = ConvBnAct(
+            bottleneck_chs, bottleneck_chs, kernel_size=3, stride=stride, dilation=first_dilation,
+            groups=groups, **cargs)
+        if se_ratio:
+            se_channels = int(round(in_chs * se_ratio))
+            self.se = SEModule(bottleneck_chs, reduction_channels=se_channels)
+        else:
+            self.se = None
+        cargs['act_layer'] = None
+        self.conv3 = ConvBnAct(bottleneck_chs, out_chs, kernel_size=1, **cargs)
+        self.act3 = act_layer(inplace=True)
+        self.downsample = downsample
+        self.drop_path = drop_path
+
+    def zero_init_last_bn(self):
+        nn.init.zeros_(self.conv3.bn.weight)
+
+    def forward(self, x):
+        shortcut = x
+        x = self.conv1(x)
+        x = self.conv2(x)
+        if self.se is not None:
+            x = self.se(x)
+        x = self.conv3(x)
+        if self.drop_path is not None:
+            x = self.drop_path(x)
+        if self.downsample is not None:
+            shortcut = self.downsample(shortcut)
+        x += shortcut
+        x = self.act3(x)
+        return x
+
+
+def downsample_conv(
+        in_chs, out_chs, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None):
+    norm_layer = norm_layer or nn.BatchNorm2d
+    kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size
+    first_dilation = (first_dilation or dilation) if kernel_size > 1 else 1
+    return ConvBnAct(
+        in_chs, out_chs, kernel_size, stride=stride, dilation=first_dilation, norm_layer=norm_layer, act_layer=None)
+
+
+def downsample_avg(
+        in_chs, out_chs, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None):
+    """ AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment."""
+    norm_layer = norm_layer or nn.BatchNorm2d
+    avg_stride = stride if dilation == 1 else 1
+    pool = nn.Identity()
+    if stride > 1 or dilation > 1:
+        avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
+        pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
+    return nn.Sequential(*[
+        pool, ConvBnAct(in_chs, out_chs, 1, stride=1, norm_layer=norm_layer, act_layer=None)])
+
+
+class RegStage(nn.Module):
+    """Stage (sequence of blocks w/ the same output shape)."""
+
+    def __init__(self, in_chs, out_chs, stride, depth, block_fn, bottle_ratio, group_width, se_ratio):
+        super(RegStage, self).__init__()
+        block_kwargs = {}  # FIXME setup to pass various aa, norm, act layer common args
+        for i in range(depth):
+            block_stride = stride if i == 0 else 1
+            block_in_chs = in_chs if i == 0 else out_chs
+            if (block_in_chs != out_chs) or (block_stride != 1):
+                proj_block = downsample_conv(block_in_chs, out_chs, 1, stride)
+            else:
+                proj_block = None
+
+            name = "b{}".format(i + 1)
+            self.add_module(
+                name, block_fn(
+                    block_in_chs, out_chs, block_stride, bottle_ratio, group_width, se_ratio,
+                    downsample=proj_block, **block_kwargs)
+            )
+
+    def forward(self, x):
+        for block in self.children():
+            x = block(x)
+        return x
+
+
+class ClassifierHead(nn.Module):
+    """Head."""
+
+    def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0.):
+        super(ClassifierHead, self).__init__()
+        self.drop_rate = drop_rate
+        self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
+        if num_classes > 0:
+            self.fc = nn.Linear(in_chs, num_classes, bias=True)
+        else:
+            self.fc = nn.Identity()
+
+    def forward(self, x):
+        x = self.global_pool(x).flatten(1)
+        if self.drop_rate:
+            x = F.dropout(x, p=float(self.drop_rate), training=self.training)
+        x = self.fc(x)
+        return x
+
+
+class RegNet(nn.Module):
+    """RegNet model.
+
+    Paper: https://arxiv.org/abs/2003.13678
+    Original Impl: https://github.com/facebookresearch/pycls/blob/master/pycls/models/regnet.py
+    """
+
+    def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0.,
+                 zero_init_last_bn=True):
+        super().__init__()
+        # TODO add drop block, drop path, anti-aliasing, custom bn/act args
+        self.num_classes = num_classes
+        self.drop_rate = drop_rate
+
+        # Construct the stem
+        stem_width = cfg['stem_width']
+        self.stem = ConvBnAct(in_chans, stem_width, 3, stride=2)
+
+        # Construct the stages
+        block_fn = Bottleneck
+        prev_width = stem_width
+        stage_params = self._get_stage_params(cfg)
+        se_ratio = cfg['se_ratio']
+        for i, (d, w, s, br, gw) in enumerate(stage_params):
+            self.add_module(
+                "s{}".format(i + 1), RegStage(prev_width, w, s, d, block_fn, br, gw, se_ratio))
+            prev_width = w
+
+        # Construct the head
+        self.num_features = prev_width
+        self.head = ClassifierHead(
+            in_chs=prev_width, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, mean=0.0, std=0.01)
+                nn.init.zeros_(m.bias)
+        if zero_init_last_bn:
+            for m in self.modules():
+                if hasattr(m, 'zero_init_last_bn'):
+                    m.zero_init_last_bn()
+
+    def _get_stage_params(self, cfg, stride=2):
+        # Generate RegNet ws per block
+        w_a, w_0, w_m, d = cfg['wa'], cfg['w0'], cfg['wm'], cfg['depth']
+        widths, num_stages, _, _ = generate_regnet(w_a, w_0, w_m, d)
+
+        # Convert to per stage format
+        stage_widths, stage_depths = np.unique(widths, return_counts=True)
+
+        # Use the same group width, bottleneck mult and stride for each stage
+        stage_groups = [cfg['group_w'] for _ in range(num_stages)]
+        stage_bottle_ratios = [cfg['bottle_ratio'] for _ in range(num_stages)]
+        stage_strides = [stride for _ in range(num_stages)]
+        # FIXME add dilation / output_stride support
+
+        # Adjust the compatibility of ws and gws
+        stage_widths, stage_groups = adjust_widths_groups_comp(stage_widths, stage_bottle_ratios, stage_groups)
+        stage_params = list(zip(stage_depths, stage_widths, stage_strides, stage_bottle_ratios, stage_groups))
+        return stage_params
+
+    def get_classifier(self):
+        return self.head.fc
+
+    def reset_classifier(self, num_classes, global_pool='avg'):
+        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
+
+    def forward_features(self, x):
+        for block in list(self.children())[:-1]:
+            x = block(x)
+        return x
+
+    def forward(self, x):
+        for block in self.children():
+            x = block(x)
+        return x
+
+
+def _regnet(variant, pretrained, **kwargs):
+    load_strict = True
+    model_class = RegNet
+    if kwargs.pop('features_only', False):
+        assert False, 'Not Implemented'  # TODO
+        load_strict = False
+        kwargs.pop('num_classes', 0)
+    model_cfg = model_cfgs[variant]
+    default_cfg = default_cfgs[variant]
+    model = model_class(model_cfg, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(
+            model, default_cfg,
+            num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=load_strict)
+    return model
+
+
"""RegNetX-400MF""" + return _regnet('x_004', pretrained, **kwargs) + + +@register_model +def regnetx_006(pretrained=False, **kwargs): + """RegNetX-600MF""" + return _regnet('x_006', pretrained, **kwargs) + + +@register_model +def regnetx_008(pretrained=False, **kwargs): + """RegNetX-800MF""" + return _regnet('x_008', pretrained, **kwargs) + + +@register_model +def regnetx_016(pretrained=False, **kwargs): + """RegNetX-1.6GF""" + return _regnet('x_016', pretrained, **kwargs) + + +@register_model +def regnetx_032(pretrained=False, **kwargs): + """RegNetX-3.2GF""" + return _regnet('x_032', pretrained, **kwargs) + + +@register_model +def regnetx_040(pretrained=False, **kwargs): + """RegNetX-4.0GF""" + return _regnet('x_040', pretrained, **kwargs) + + +@register_model +def regnetx_064(pretrained=False, **kwargs): + """RegNetX-6.4GF""" + return _regnet('x_064', pretrained, **kwargs) + + +@register_model +def regnetx_080(pretrained=False, **kwargs): + """RegNetX-8.0GF""" + return _regnet('x_080', pretrained, **kwargs) + + +@register_model +def regnetx_120(pretrained=False, **kwargs): + """RegNetX-12GF""" + return _regnet('x_120', pretrained, **kwargs) + + +@register_model +def regnetx_160(pretrained=False, **kwargs): + """RegNetX-16GF""" + return _regnet('x_160', pretrained, **kwargs) + + +@register_model +def regnetx_320(pretrained=False, **kwargs): + """RegNetX-32GF""" + return _regnet('x_320', pretrained, **kwargs) + + +@register_model +def regnety_002(pretrained=False, **kwargs): + """RegNetY-200MF""" + return _regnet('y_002', pretrained, **kwargs) + + +@register_model +def regnety_004(pretrained=False, **kwargs): + """RegNetY-400MF""" + return _regnet('y_004', pretrained, **kwargs) + + +@register_model +def regnety_006(pretrained=False, **kwargs): + """RegNetY-600MF""" + return _regnet('y_006', pretrained, **kwargs) + + +@register_model +def regnety_008(pretrained=False, **kwargs): + """RegNetY-800MF""" + return _regnet('y_008', pretrained, **kwargs) + + +@register_model +def regnety_016(pretrained=False, **kwargs): + """RegNetY-1.6GF""" + return _regnet('y_016', pretrained, **kwargs) + + +@register_model +def regnety_032(pretrained=False, **kwargs): + """RegNetY-3.2GF""" + return _regnet('y_032', pretrained, **kwargs) + + +@register_model +def regnety_040(pretrained=False, **kwargs): + """RegNetY-4.0GF""" + return _regnet('y_040', pretrained, **kwargs) + + +@register_model +def regnety_064(pretrained=False, **kwargs): + """RegNetY-6.4GF""" + return _regnet('y_064', pretrained, **kwargs) + + +@register_model +def regnety_080(pretrained=False, **kwargs): + """RegNetY-8.0GF""" + return _regnet('y_080', pretrained, **kwargs) + + +@register_model +def regnety_120(pretrained=False, **kwargs): + """RegNetY-12GF""" + return _regnet('y_120', pretrained, **kwargs) + + +@register_model +def regnety_160(pretrained=False, **kwargs): + """RegNetY-16GF""" + return _regnet('y_160', pretrained, **kwargs) + + +@register_model +def regnety_320(pretrained=False, **kwargs): + """RegNetY-32GF""" + return _regnet('y_320', pretrained, **kwargs) From afb6bd066910e71fb7b2621ab098a3826ccc27b8 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 21 May 2020 15:28:36 -0700 Subject: [PATCH 2/7] Add backward and default_cfg tests and fix a few issues found. 
From afb6bd066910e71fb7b2621ab098a3826ccc27b8 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 21 May 2020 15:28:36 -0700
Subject: [PATCH 2/7] Add backward and default_cfg tests and fix a few issues
 found.

Fix #153
---
 tests/test_inference.py       | 19 ----------
 tests/test_models.py          | 70 +++++++++++++++++++++++++++++++++++
 timm/models/dla.py            | 10 +++--
 timm/models/gluon_xception.py |  2 +-
 timm/models/hrnet.py          |  2 +-
 timm/models/inception_v3.py   |  2 +-
 timm/models/mobilenetv3.py    |  2 +-
 timm/models/nasnet.py         |  4 +-
 timm/models/resnest.py        |  9 +++--
 timm/models/selecsls.py       |  2 +-
 timm/models/tresnet.py        |  8 ++--
 timm/models/xception.py       |  1 +
 12 files changed, 95 insertions(+), 36 deletions(-)
 delete mode 100644 tests/test_inference.py
 create mode 100644 tests/test_models.py

diff --git a/tests/test_inference.py b/tests/test_inference.py
deleted file mode 100644
index 2490a0bc..00000000
--- a/tests/test_inference.py
+++ /dev/null
@@ -1,19 +0,0 @@
-import pytest
-import torch
-
-from timm import list_models, create_model
-
-
-@pytest.mark.timeout(300)
-@pytest.mark.parametrize('model_name', list_models(exclude_filters='*efficientnet_l2*'))
-@pytest.mark.parametrize('batch_size', [1])
-def test_model_forward(model_name, batch_size):
-    """Run a single forward pass with each model"""
-    model = create_model(model_name, pretrained=False)
-    model.eval()
-
-    inputs = torch.randn((batch_size, *model.default_cfg['input_size']))
-    outputs = model(inputs)
-
-    assert outputs.shape[0] == batch_size
-    assert not torch.isnan(outputs).any(), 'Output included NaNs'

diff --git a/tests/test_models.py b/tests/test_models.py
new file mode 100644
index 00000000..65a7ebb3
--- /dev/null
+++ b/tests/test_models.py
@@ -0,0 +1,70 @@
+import pytest
+import torch
+
+from timm import list_models, create_model
+
+
+@pytest.mark.timeout(120)
+@pytest.mark.parametrize('model_name', list_models())
+@pytest.mark.parametrize('batch_size', [1])
+def test_model_forward(model_name, batch_size):
+    """Run a single forward pass with each model"""
+    model = create_model(model_name, pretrained=False)
+    model.eval()
+
+    input_size = model.default_cfg['input_size']
+    if any([x > 448 for x in input_size]):
+        # cap forward test at max res 448 * 448 to keep resource down
+        input_size = tuple([min(x, 448) for x in input_size])
+    inputs = torch.randn((batch_size, *input_size))
+    outputs = model(inputs)
+
+    assert outputs.shape[0] == batch_size
+    assert not torch.isnan(outputs).any(), 'Output included NaNs'
+
+
+@pytest.mark.timeout(120)
+@pytest.mark.parametrize('model_name', list_models(exclude_filters='dla*'))  # DLA models have an issue TBD
+@pytest.mark.parametrize('batch_size', [2])
+def test_model_backward(model_name, batch_size):
+    """Run a single forward + backward pass with each model"""
+    model = create_model(model_name, pretrained=False, num_classes=42)
+    num_params = sum([x.numel() for x in model.parameters()])
+    model.eval()
+
+    input_size = model.default_cfg['input_size']
+    if any([x > 128 for x in input_size]):
+        # cap backward test at 128 * 128 to keep resource usage down
+        input_size = tuple([min(x, 128) for x in input_size])
+    inputs = torch.randn((batch_size, *input_size))
+    outputs = model(inputs)
+    outputs.mean().backward()
+    num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None])
+
+    assert outputs.shape[-1] == 42
+    assert num_params == num_grad, 'Some parameters are missing gradients'
+    assert not torch.isnan(outputs).any(), 'Output included NaNs'
+
+
+@pytest.mark.timeout(120)
+@pytest.mark.parametrize('model_name', list_models())
+@pytest.mark.parametrize('batch_size', [1])
+def test_model_default_cfgs(model_name, batch_size):
+    """Run a single forward pass and check default_cfg metadata for each model"""
+    model = create_model(model_name, pretrained=False)
+    model.eval()
+    state_dict = model.state_dict()
+    cfg = model.default_cfg
+
+    classifier = cfg['classifier']
+    first_conv = cfg['first_conv']
+    pool_size = cfg['pool_size']
+    input_size = model.default_cfg['input_size']
+
+    if all([x <= 448 for x in input_size]):
+        # pool size only checked if default res <= 448 * 448 to keep resource down
+        input_size = tuple([min(x, 448) for x in input_size])
+        outputs = model.forward_features(torch.randn((batch_size, *input_size)))
+        assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
+    assert any([k.startswith(cfg['classifier']) for k in state_dict.keys()]), f'{classifier} not in model params'
+    assert any([k.startswith(cfg['first_conv']) for k in state_dict.keys()]), f'{first_conv} not in model params'
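
[Editor's note (illustrative): the pool_size assertion above relies on the typical
32x spatial reduction of these backbones, e.g. for a standard ImageNet config:

    input_size = (3, 224, 224)                                # default_cfg['input_size']
    pool_size = (input_size[-2] // 32, input_size[-1] // 32)  # -> (7, 7)

Models whose forward_features() deviates from that, such as the (1, 1) MobileNetV3
case, get their pool_size metadata corrected in the file diffs below.]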
diff --git a/timm/models/dla.py b/timm/models/dla.py
index a9e81d16..94803e69 100644
--- a/timm/models/dla.py
+++ b/timm/models/dla.py
@@ -237,8 +237,11 @@ class DlaTree(nn.Module):
 
     def forward(self, x, residual=None, children=None):
         children = [] if children is None else children
-        bottom = self.downsample(x) if self.downsample else x
-        residual = self.project(bottom) if self.project else bottom
+        # FIXME the way downsample / project are used here and residual is passed to next level up
+        # the tree, the residual is overridden and some project weights are thus never used and
+        # have no gradients. This appears to be an issue with the original model / weights.
+        bottom = self.downsample(x) if self.downsample is not None else x
+        residual = self.project(bottom) if self.project is not None else bottom
         if self.level_root:
             children.append(bottom)
         x1 = self.tree1(x, residual)
@@ -354,7 +357,8 @@ def dla60_res2next(pretrained=None, num_classes=1000, in_chans=3, **kwargs):
 @register_model
 def dla34(pretrained=None, num_classes=1000, in_chans=3, **kwargs):  # DLA-34
     default_cfg = default_cfgs['dla34']
-    model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512], block=DlaBasic, **kwargs)
+    model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512], block=DlaBasic,
+                num_classes=num_classes, in_chans=in_chans, **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
         load_pretrained(model, default_cfg, num_classes, in_chans)

diff --git a/timm/models/gluon_xception.py b/timm/models/gluon_xception.py
index 2fc8e699..a737b8f7 100644
--- a/timm/models/gluon_xception.py
+++ b/timm/models/gluon_xception.py
@@ -36,7 +36,7 @@ default_cfgs = {
         'url': '',
         'input_size': (3, 299, 299),
         'crop_pct': 0.875,
-        'pool_size': (10, 10),
+        'pool_size': (5, 5),
         'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN,
         'std': IMAGENET_DEFAULT_STD,

diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py
index 06327c65..ac4824bb 100644
--- a/timm/models/hrnet.py
+++ b/timm/models/hrnet.py
@@ -34,7 +34,7 @@ def _cfg(url='', **kwargs):
         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
         'crop_pct': 0.875, 'interpolation': 'bilinear',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'conv1', 'classifier': 'fc',
+        'first_conv': 'conv1', 'classifier': 'classifier',
         **kwargs
     }
diff --git a/timm/models/inception_v3.py b/timm/models/inception_v3.py
index 0997e024..ffaab4f1 100644
--- a/timm/models/inception_v3.py
+++ b/timm/models/inception_v3.py
@@ -15,7 +15,7 @@ def _cfg(url='', **kwargs):
         'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
         'crop_pct': 0.875, 'interpolation': 'bicubic',
         'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
-        'first_conv': 'conv1', 'classifier': 'fc',
+        'first_conv': 'Conv2d_1a_3x3', 'classifier': 'fc',
         **kwargs
     }

diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py
index 86ca9f7a..9c0e863a 100644
--- a/timm/models/mobilenetv3.py
+++ b/timm/models/mobilenetv3.py
@@ -21,7 +21,7 @@ __all__ = ['MobileNetV3']
 
 def _cfg(url='', **kwargs):
     return {
-        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (1, 1),
         'crop_pct': 0.875, 'interpolation': 'bilinear',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
         'first_conv': 'conv_stem', 'classifier': 'classifier',

diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py
index 8847b1de..511b006b 100644
--- a/timm/models/nasnet.py
+++ b/timm/models/nasnet.py
@@ -19,7 +19,7 @@ default_cfgs = {
         'mean': (0.5, 0.5, 0.5),
         'std': (0.5, 0.5, 0.5),
         'num_classes': 1001,
-        'first_conv': 'conv_0.conv',
+        'first_conv': 'conv0.conv',
         'classifier': 'last_linear',
     },
 }
@@ -612,7 +612,7 @@ def nasnetalarge(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     """NASNet-A large model architecture.
     """
     default_cfg = default_cfgs['nasnetalarge']
-    model = NASNetALarge(num_classes=1000, in_chans=in_chans, **kwargs)
+    model = NASNetALarge(num_classes=num_classes, in_chans=in_chans, **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
         load_pretrained(model, default_cfg, num_classes, in_chans)

diff --git a/timm/models/resnest.py b/timm/models/resnest.py
index 33b051ef..884894d9 100644
--- a/timm/models/resnest.py
+++ b/timm/models/resnest.py
@@ -38,11 +38,14 @@ default_cfgs = {
     'resnest50d': _cfg(
         url='https://hangzh.s3.amazonaws.com/encoding/models/resnest50-528c19ca.pth'),
     'resnest101e': _cfg(
-        url='https://hangzh.s3.amazonaws.com/encoding/models/resnest101-22405ba7.pth', input_size=(3, 256, 256)),
+        url='https://hangzh.s3.amazonaws.com/encoding/models/resnest101-22405ba7.pth',
+        input_size=(3, 256, 256), pool_size=(8, 8)),
     'resnest200e': _cfg(
-        url='https://hangzh.s3.amazonaws.com/encoding/models/resnest200-75117900.pth', input_size=(3, 320, 320)),
+        url='https://hangzh.s3.amazonaws.com/encoding/models/resnest200-75117900.pth',
+        input_size=(3, 320, 320), pool_size=(10, 10)),
     'resnest269e': _cfg(
-        url='https://hangzh.s3.amazonaws.com/encoding/models/resnest269-0cc87c48.pth', input_size=(3, 416, 416)),
+        url='https://hangzh.s3.amazonaws.com/encoding/models/resnest269-0cc87c48.pth',
+        input_size=(3, 416, 416), pool_size=(13, 13)),
     'resnest50d_4s2x40d': _cfg(
         url='https://hangzh.s3.amazonaws.com/encoding/models/resnest50_fast_4s2x40d-41d14ed0.pth',
         interpolation='bicubic'),

diff --git a/timm/models/selecsls.py b/timm/models/selecsls.py
index 2f369e99..6b83421b 100644
--- a/timm/models/selecsls.py
+++ b/timm/models/selecsls.py
@@ -26,7 +26,7 @@ __all__ = ['SelecSLS']  # model_registry will add each entrypoint fn to this
 def _cfg(url='', **kwargs):
     return {
         'url': url,
-        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (3, 3),
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (4, 4),
         'crop_pct': 0.875, 'interpolation': 'bilinear',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
         'first_conv': 'stem', 'classifier': 'fc',
diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py
index 48b3e1de..fbbcf318 100644
--- a/timm/models/tresnet.py
+++ b/timm/models/tresnet.py
@@ -28,7 +28,7 @@ def _cfg(url='', **kwargs):
         'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
         'crop_pct': 0.875, 'interpolation': 'bilinear',
         'mean': (0, 0, 0), 'std': (1, 1, 1),
-        'first_conv': 'layer0.conv1', 'classifier': 'head.fc',
+        'first_conv': 'body.conv1', 'classifier': 'head.fc',
         **kwargs
     }
@@ -41,13 +41,13 @@ default_cfgs = {
     'tresnet_xl': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_xl_82_0-a2d51b00.pth'),
     'tresnet_m_448': _cfg(
-        input_size=(3, 448, 448),
+        input_size=(3, 448, 448), pool_size=(14, 14),
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_m_448-bc359d10.pth'),
     'tresnet_l_448': _cfg(
-        input_size=(3, 448, 448),
+        input_size=(3, 448, 448), pool_size=(14, 14),
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_448-940d0cd1.pth'),
     'tresnet_xl_448': _cfg(
-        input_size=(3, 448, 448),
+        input_size=(3, 448, 448), pool_size=(14, 14),
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_xl_448-8c1815de.pth')
 }

diff --git a/timm/models/xception.py b/timm/models/xception.py
index cb98bbc9..f04dabfd 100644
--- a/timm/models/xception.py
+++ b/timm/models/xception.py
@@ -37,6 +37,7 @@ default_cfgs = {
     'xception': {
         'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/xception-43020ad28.pth',
         'input_size': (3, 299, 299),
+        'pool_size': (10, 10),
         'crop_pct': 0.8975,
         'interpolation': 'bicubic',
         'mean': (0.5, 0.5, 0.5),

From 3873ea710e9292b36f6b800bf153f6521d7568b6 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 21 May 2020 15:51:47 -0700
Subject: [PATCH 3/7] Minor test change

---
 tests/test_models.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/test_models.py b/tests/test_models.py
index 65a7ebb3..fd99bb46 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -66,5 +66,5 @@ def test_model_default_cfgs(model_name, batch_size):
         input_size = tuple([min(x, 448) for x in input_size])
         outputs = model.forward_features(torch.randn((batch_size, *input_size)))
         assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
-    assert any([k.startswith(cfg['classifier']) for k in state_dict.keys()]), f'{classifier} not in model params'
-    assert any([k.startswith(cfg['first_conv']) for k in state_dict.keys()]), f'{first_conv} not in model params'
+    assert any([k.startswith(classifier) for k in state_dict.keys()]), f'{classifier} not in model params'
+    assert any([k.startswith(first_conv) for k in state_dict.keys()]), f'{first_conv} not in model params'
From 20329f263033137d6cd4cfe87bc817a74d4e88f0 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 21 May 2020 16:49:46 -0700
Subject: [PATCH 4/7] Bring down test resolutions to see if we can at least do
 a fwd on the L2 models

---
 tests/test_models.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/tests/test_models.py b/tests/test_models.py
index fd99bb46..822e0f2f 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -3,6 +3,10 @@ import torch
 
 from timm import list_models, create_model
 
+MAX_FWD_SIZE = 320
+MAX_BWD_SIZE = 128
+MAX_FWD_FEAT_SIZE = 448
+
 
 @pytest.mark.timeout(120)
 @pytest.mark.parametrize('model_name', list_models())
 @pytest.mark.parametrize('batch_size', [1])
@@ -13,9 +17,9 @@ def test_model_forward(model_name, batch_size):
     model.eval()
 
     input_size = model.default_cfg['input_size']
-    if any([x > 448 for x in input_size]):
+    if any([x > MAX_FWD_SIZE for x in input_size]):
         # cap forward test at max res 448 * 448 to keep resource down
-        input_size = tuple([min(x, 448) for x in input_size])
+        input_size = tuple([min(x, MAX_FWD_SIZE) for x in input_size])
     inputs = torch.randn((batch_size, *input_size))
     outputs = model(inputs)
@@ -33,9 +37,9 @@ def test_model_backward(model_name, batch_size):
     model.eval()
 
     input_size = model.default_cfg['input_size']
-    if any([x > 128 for x in input_size]):
+    if any([x > MAX_BWD_SIZE for x in input_size]):
         # cap backward test at 128 * 128 to keep resource usage down
-        input_size = tuple([min(x, 128) for x in input_size])
+        input_size = tuple([min(x, MAX_BWD_SIZE) for x in input_size])
     inputs = torch.randn((batch_size, *input_size))
     outputs = model(inputs)
     outputs.mean().backward()
@@ -61,9 +65,9 @@ def test_model_default_cfgs(model_name, batch_size):
     pool_size = cfg['pool_size']
     input_size = model.default_cfg['input_size']
 
-    if all([x <= 448 for x in input_size]):
+    if all([x <= MAX_FWD_FEAT_SIZE for x in input_size]) and 'efficientnet_l2' not in model_name:
         # pool size only checked if default res <= 448 * 448 to keep resource down
-        input_size = tuple([min(x, 448) for x in input_size])
+        input_size = tuple([min(x, MAX_FWD_FEAT_SIZE) for x in input_size])
         outputs = model.forward_features(torch.randn((batch_size, *input_size)))
         assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
         assert any([k.startswith(classifier) for k in state_dict.keys()]), f'{classifier} not in model params'

From 4212cd3b9ff184677e42a3943ab773db8024ea98 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 21 May 2020 18:55:10 -0700
Subject: [PATCH 5/7] Another attempt at getting Ubuntu test runner to work

---
 tests/test_models.py | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/tests/test_models.py b/tests/test_models.py
index 822e0f2f..5c79dd2e 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -1,15 +1,24 @@
 import pytest
 import torch
+import platform
+import os
+import fnmatch
 
 from timm import list_models, create_model
 
-MAX_FWD_SIZE = 320
+
+if 'GITHUB_ACTIONS' in os.environ and 'Linux' in platform.system():
+    # GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
+    EXCLUDE_FILTERS = ['*efficientnet_l2*']
+else:
+    EXCLUDE_FILTERS = []
+MAX_FWD_SIZE = 384
 MAX_BWD_SIZE = 128
 MAX_FWD_FEAT_SIZE = 448
 
 
 @pytest.mark.timeout(120)
-@pytest.mark.parametrize('model_name', list_models())
+@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
 @pytest.mark.parametrize('batch_size', [1])
 def test_model_forward(model_name, batch_size):
     """Run a single forward pass with each model"""
@@ -28,7 +37,8 @@ def test_model_forward(model_name, batch_size):
 
 @pytest.mark.timeout(120)
-@pytest.mark.parametrize('model_name', list_models(exclude_filters='dla*'))  # DLA models have an issue TBD
+# DLA models have an issue TBD, add them to exclusions
+@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + ['dla*']))
 @pytest.mark.parametrize('batch_size', [2])
 def test_model_backward(model_name, batch_size):
     """Run a single forward + backward pass with each model"""
     model = create_model(model_name, pretrained=False, num_classes=42)
@@ -65,9 +75,8 @@ def test_model_default_cfgs(model_name, batch_size):
     pool_size = cfg['pool_size']
     input_size = model.default_cfg['input_size']
 
-    if all([x <= MAX_FWD_FEAT_SIZE for x in input_size]) and 'efficientnet_l2' not in model_name:
+    if all([x <= MAX_FWD_FEAT_SIZE for x in input_size]) and \
+            not any([fnmatch.fnmatch(model_name, x) for x in EXCLUDE_FILTERS]):
         # pool size only checked if default res <= 448 * 448 to keep resource down
         input_size = tuple([min(x, MAX_FWD_FEAT_SIZE) for x in input_size])
         outputs = model.forward_features(torch.randn((batch_size, *input_size)))
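
[Editor's note (illustrative, not part of the patch series): the EXCLUDE_FILTERS
wildcards introduced above are matched against model names with fnmatch, e.g.:

    import fnmatch
    EXCLUDE_FILTERS = ['*efficientnet_l2*']
    assert any([fnmatch.fnmatch('tf_efficientnet_l2_ns', f) for f in EXCLUDE_FILTERS])

The same patterns are passed to list_models(exclude_filters=...) so those models are
skipped entirely; the next patch grows the exclusion list.]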
From 4d13db538f6736473b21b52890b6df6dfdbaff7c Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 21 May 2020 19:13:41 -0700
Subject: [PATCH 6/7] Add x48d ResNext101s to test exclude for ubuntu

---
 tests/test_models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_models.py b/tests/test_models.py
index 5c79dd2e..02cb61bb 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -9,7 +9,7 @@ from timm import list_models, create_model
 
 if 'GITHUB_ACTIONS' in os.environ and 'Linux' in platform.system():
     # GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
-    EXCLUDE_FILTERS = ['*efficientnet_l2*']
+    EXCLUDE_FILTERS = ['*efficientnet_l2*', '*resnext101_32x48d']
 else:
     EXCLUDE_FILTERS = []
 MAX_FWD_SIZE = 384

From d79ac48626dce7758eddda9405096bb6e23d1cec Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 22 May 2020 14:42:43 -0700
Subject: [PATCH 7/7] Update sotabench.py

---
 sotabench.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/sotabench.py b/sotabench.py
index de3828f4..c394d062 100644
--- a/sotabench.py
+++ b/sotabench.py
@@ -126,6 +126,15 @@ model_list = [
     _entry('skresnet34', 'SK-ResNet-34', '1903.06586'),
     _entry('skresnext50_32x4d', 'SKNet-50', '1903.06586'),
 
+    _entry('ecaresnetlight', 'ECA-ResNet-Light', '1910.03151',
+        model_desc='A tweaked ResNet50d with ECA attn.'),
+    _entry('ecaresnet50d', 'ECA-ResNet-50d', '1910.03151',
+        model_desc='A ResNet50d with ECA attn'),
+    _entry('ecaresnet101d', 'ECA-ResNet-101d', '1910.03151',
+        model_desc='A ResNet101d with ECA attn'),
+
+    _entry('resnetblur50', 'ResNet-Blur-50', '1904.11486'),
+
     _entry('tf_efficientnet_b0', 'EfficientNet-B0 (AutoAugment)', '1905.11946',
         model_desc='Ported from official Google AI Tensorflow weights'),
     _entry('tf_efficientnet_b1', 'EfficientNet-B1 (AutoAugment)', '1905.11946',