Merge branch 'master' into densenet_update_and_more

pull/155/head
Ross Wightman 5 years ago
commit 0ea53cecc3

@ -126,6 +126,15 @@ model_list = [
_entry('skresnet34', 'SK-ResNet-34', '1903.06586'),
_entry('skresnext50_32x4d', 'SKNet-50', '1903.06586'),
_entry('ecaresnetlight', 'ECA-ResNet-Light', '1910.03151',
model_desc='A tweaked ResNet50d with ECA attn.'),
_entry('ecaresnet50d', 'ECA-ResNet-50d', '1910.03151',
model_desc='A ResNet50d with ECA attn'),
_entry('ecaresnet101d', 'ECA-ResNet-101d', '1910.03151',
model_desc='A ResNet101d with ECA attn'),
_entry('resnetblur50', 'ResNet-Blur-50', '1904.11486'),
_entry('tf_efficientnet_b0', 'EfficientNet-B0 (AutoAugment)', '1905.11946',
model_desc='Ported from official Google AI Tensorflow weights'),
_entry('tf_efficientnet_b1', 'EfficientNet-B1 (AutoAugment)', '1905.11946',

@ -1,19 +0,0 @@
import pytest
import torch
from timm import list_models, create_model
@pytest.mark.timeout(300)
@pytest.mark.parametrize('model_name', list_models(exclude_filters='*efficientnet_l2*'))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward(model_name, batch_size):
"""Run a single forward pass with each model"""
model = create_model(model_name, pretrained=False)
model.eval()
inputs = torch.randn((batch_size, *model.default_cfg['input_size']))
outputs = model(inputs)
assert outputs.shape[0] == batch_size
assert not torch.isnan(outputs).any(), 'Output included NaNs'

@ -0,0 +1,85 @@
import pytest
import torch
import platform
import os
import fnmatch
from timm import list_models, create_model
if 'GITHUB_ACTIONS' in os.environ and 'Linux' in platform.system():
# GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
EXCLUDE_FILTERS = ['*efficientnet_l2*', '*resnext101_32x48d']
else:
EXCLUDE_FILTERS = []
MAX_FWD_SIZE = 384
MAX_BWD_SIZE = 128
MAX_FWD_FEAT_SIZE = 448
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward(model_name, batch_size):
"""Run a single forward pass with each model"""
model = create_model(model_name, pretrained=False)
model.eval()
input_size = model.default_cfg['input_size']
if any([x > MAX_FWD_SIZE for x in input_size]):
# cap forward test at max res 448 * 448 to keep resource down
input_size = tuple([min(x, MAX_FWD_SIZE) for x in input_size])
inputs = torch.randn((batch_size, *input_size))
outputs = model(inputs)
assert outputs.shape[0] == batch_size
assert not torch.isnan(outputs).any(), 'Output included NaNs'
@pytest.mark.timeout(120)
# DLA models have an issue TBD, add them to exclusions
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + ['dla*']))
@pytest.mark.parametrize('batch_size', [2])
def test_model_backward(model_name, batch_size):
"""Run a single forward pass with each model"""
model = create_model(model_name, pretrained=False, num_classes=42)
num_params = sum([x.numel() for x in model.parameters()])
model.eval()
input_size = model.default_cfg['input_size']
if any([x > MAX_BWD_SIZE for x in input_size]):
# cap backward test at 128 * 128 to keep resource usage down
input_size = tuple([min(x, MAX_BWD_SIZE) for x in input_size])
inputs = torch.randn((batch_size, *input_size))
outputs = model(inputs)
outputs.mean().backward()
num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None])
assert outputs.shape[-1] == 42
assert num_params == num_grad, 'Some parameters are missing gradients'
assert not torch.isnan(outputs).any(), 'Output included NaNs'
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models())
@pytest.mark.parametrize('batch_size', [1])
def test_model_default_cfgs(model_name, batch_size):
"""Run a single forward pass with each model"""
model = create_model(model_name, pretrained=False)
model.eval()
state_dict = model.state_dict()
cfg = model.default_cfg
classifier = cfg['classifier']
first_conv = cfg['first_conv']
pool_size = cfg['pool_size']
input_size = model.default_cfg['input_size']
if all([x <= MAX_FWD_FEAT_SIZE for x in input_size]) and \
not any([fnmatch.fnmatch(model_name, x) for x in EXCLUDE_FILTERS]):
# pool size only checked if default res <= 448 * 448 to keep resource down
input_size = tuple([min(x, MAX_FWD_FEAT_SIZE) for x in input_size])
outputs = model.forward_features(torch.randn((batch_size, *input_size)))
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
assert any([k.startswith(classifier) for k in state_dict.keys()]), f'{classifier} not in model params'
assert any([k.startswith(first_conv) for k in state_dict.keys()]), f'{first_conv} not in model params'

@ -19,6 +19,7 @@ from .hrnet import *
from .sknet import *
from .tresnet import *
from .resnest import *
from .regnet import *
from .registry import *
from .factory import create_model

@ -237,8 +237,11 @@ class DlaTree(nn.Module):
def forward(self, x, residual=None, children=None):
children = [] if children is None else children
bottom = self.downsample(x) if self.downsample else x
residual = self.project(bottom) if self.project else bottom
# FIXME the way downsample / project are used here and residual is passed to next level up
# the tree, the residual is overridden and some project weights are thus never used and
# have no gradients. This appears to be an issue with the original model / weights.
bottom = self.downsample(x) if self.downsample is not None else x
residual = self.project(bottom) if self.project is not None else bottom
if self.level_root:
children.append(bottom)
x1 = self.tree1(x, residual)
@ -355,7 +358,8 @@ def dla60_res2next(pretrained=None, num_classes=1000, in_chans=3, **kwargs):
@register_model
def dla34(pretrained=None, num_classes=1000, in_chans=3, **kwargs): # DLA-34
default_cfg = default_cfgs['dla34']
model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512], block=DlaBasic, **kwargs)
model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512], block=DlaBasic,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)

@ -36,7 +36,7 @@ default_cfgs = {
'url': '',
'input_size': (3, 299, 299),
'crop_pct': 0.875,
'pool_size': (10, 10),
'pool_size': (5, 5),
'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN,
'std': IMAGENET_DEFAULT_STD,

@ -34,7 +34,7 @@ def _cfg(url='', **kwargs):
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bilinear',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'conv1', 'classifier': 'fc',
'first_conv': 'conv1', 'classifier': 'classifier',
**kwargs
}

@ -14,7 +14,7 @@ def _cfg(url='', **kwargs):
'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
'crop_pct': 0.875, 'interpolation': 'bicubic',
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
'first_conv': 'conv1', 'classifier': 'fc',
'first_conv': 'Conv2d_1a_3x3', 'classifier': 'fc',
**kwargs
}

@ -3,10 +3,10 @@ from torch import nn as nn
class SEModule(nn.Module):
def __init__(self, channels, reduction=16, act_layer=nn.ReLU):
def __init__(self, channels, reduction=16, act_layer=nn.ReLU, min_channels=8, reduction_channels=None):
super(SEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
reduction_channels = max(channels // reduction, 8)
reduction_channels = reduction_channels or max(channels // reduction, min_channels)
self.fc1 = nn.Conv2d(
channels, reduction_channels, kernel_size=1, padding=0, bias=True)
self.act = act_layer(inplace=True)

@ -21,7 +21,7 @@ __all__ = ['MobileNetV3']
def _cfg(url='', **kwargs):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (1, 1),
'crop_pct': 0.875, 'interpolation': 'bilinear',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'conv_stem', 'classifier': 'classifier',

@ -18,7 +18,7 @@ default_cfgs = {
'mean': (0.5, 0.5, 0.5),
'std': (0.5, 0.5, 0.5),
'num_classes': 1001,
'first_conv': 'conv_0.conv',
'first_conv': 'conv0.conv',
'classifier': 'last_linear',
},
}
@ -613,7 +613,7 @@ def nasnetalarge(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""NASNet-A large model architecture.
"""
default_cfg = default_cfgs['nasnetalarge']
model = NASNetALarge(num_classes=1000, in_chans=in_chans, **kwargs)
model = NASNetALarge(num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)

@ -0,0 +1,485 @@
"""RegNet
Paper: `Designing Network Design Spaces` - https://arxiv.org/abs/2003.13678
Original Impl: https://github.com/facebookresearch/pycls/blob/master/pycls/models/regnet.py
Based on original PyTorch impl linked above, but re-wrote to use my own blocks (adapted from ResNet here)
and cleaned up with more descriptive variable names.
Weights from original impl have been modified
* first layer from BGR -> RGB as most PyTorch models are
* removed training specific dict entries from checkpoints and keep model state_dict only
* remap names to match the ones here
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from .registry import register_model
from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d, AvgPool2dSame, ConvBnAct, SEModule
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
def _mcfg(**kwargs):
cfg = dict(se_ratio=0., bottle_ratio=1., stem_width=32)
cfg.update(**kwargs)
return cfg
# Model FLOPS = three trailing digits * 10^8
model_cfgs = dict(
x_002=_mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13),
x_004=_mcfg(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22),
x_006=_mcfg(w0=48, wa=36.97, wm=2.24, group_w=24, depth=16),
x_008=_mcfg(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16),
x_016=_mcfg(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18),
x_032=_mcfg(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25),
x_040=_mcfg(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23),
x_064=_mcfg(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17),
x_080=_mcfg(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23),
x_120=_mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19),
x_160=_mcfg(w0=216, wa=55.59, wm=2.1, group_w=128, depth=22),
x_320=_mcfg(w0=320, wa=69.86, wm=2.0, group_w=168, depth=23),
y_002=_mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13, se_ratio=0.25),
y_004=_mcfg(w0=48, wa=27.89, wm=2.09, group_w=8, depth=16, se_ratio=0.25),
y_006=_mcfg(w0=48, wa=32.54, wm=2.32, group_w=16, depth=15, se_ratio=0.25),
y_008=_mcfg(w0=56, wa=38.84, wm=2.4, group_w=16, depth=14, se_ratio=0.25),
y_016=_mcfg(w0=48, wa=20.71, wm=2.65, group_w=24, depth=27, se_ratio=0.25),
y_032=_mcfg(w0=80, wa=42.63, wm=2.66, group_w=24, depth=21, se_ratio=0.25),
y_040=_mcfg(w0=96, wa=31.41, wm=2.24, group_w=64, depth=22, se_ratio=0.25),
y_064=_mcfg(w0=112, wa=33.22, wm=2.27, group_w=72, depth=25, se_ratio=0.25),
y_080=_mcfg(w0=192, wa=76.82, wm=2.19, group_w=56, depth=17, se_ratio=0.25),
y_120=_mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, se_ratio=0.25),
y_160=_mcfg(w0=200, wa=106.23, wm=2.48, group_w=112, depth=18, se_ratio=0.25),
y_320=_mcfg(w0=232, wa=115.89, wm=2.53, group_w=232, depth=20, se_ratio=0.25),
)
def _cfg(url=''):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.conv', 'classifier': 'head.fc',
}
default_cfgs = dict(
x_002=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_002-e7e85e5c.pth'),
x_004=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_004-7d0e9424.pth'),
x_006=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_006-85ec1baa.pth'),
x_008=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_008-d8b470eb.pth'),
x_016=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_016-65ca972a.pth'),
x_032=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_032-ed0c7f7e.pth'),
x_040=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_040-73c2a654.pth'),
x_064=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_064-29278baa.pth'),
x_080=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_080-7c7fcab1.pth'),
x_120=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_120-65d5521e.pth'),
x_160=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_160-c98c4112.pth'),
x_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_320-8ea38b93.pth'),
y_002=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_002-e68ca334.pth'),
y_004=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_004-0db870e6.pth'),
y_006=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_006-c67e57ec.pth'),
y_008=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_008-dc900dbe.pth'),
y_016=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_016-54367f74.pth'),
y_032=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_032-62b47782.pth'),
y_040=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_040-f0d569f9.pth'),
y_064=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_064-0a48325c.pth'),
y_080=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_080-e7f3eb93.pth'),
y_120=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_120-721ba79a.pth'),
y_160=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_160-d64013cd.pth'),
y_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'),
)
def quantize_float(f, q):
"""Converts a float to closest non-zero int divisible by q."""
return int(round(f / q) * q)
def adjust_widths_groups_comp(widths, bottle_ratios, groups):
"""Adjusts the compatibility of widths and groups."""
bottleneck_widths = [int(w * b) for w, b in zip(widths, bottle_ratios)]
groups = [min(g, w_bot) for g, w_bot in zip(groups, bottleneck_widths)]
bottleneck_widths = [quantize_float(w_bot, g) for w_bot, g in zip(bottleneck_widths, groups)]
widths = [int(w_bot / b) for w_bot, b in zip(bottleneck_widths, bottle_ratios)]
return widths, groups
def generate_regnet(width_slope, width_initial, width_mult, depth, q=8):
"""Generates per block widths from RegNet parameters."""
assert width_slope >= 0 and width_initial > 0 and width_mult > 1 and width_initial % q == 0
widths_cont = np.arange(depth) * width_slope + width_initial
width_exps = np.round(np.log(widths_cont / width_initial) / np.log(width_mult))
widths = width_initial * np.power(width_mult, width_exps)
widths = np.round(np.divide(widths, q)) * q
num_stages, max_stage = len(np.unique(widths)), width_exps.max() + 1
widths, widths_cont = widths.astype(int).tolist(), widths_cont.tolist()
return widths, num_stages, max_stage, widths_cont
class Bottleneck(nn.Module):
""" RegNet Bottleneck
This is almost exactly the same as a ResNet Bottlneck. The main difference is the SE block is moved from
after conv3 to after conv2. Otherwise, it's just redefining the arguments for groups/bottleneck channels.
"""
def __init__(self, in_chs, out_chs, stride=1, bottleneck_ratio=1, group_width=1, se_ratio=0.25,
dilation=1, first_dilation=None, downsample=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
aa_layer=None, drop_block=None, drop_path=None):
super(Bottleneck, self).__init__()
bottleneck_chs = int(round(out_chs * bottleneck_ratio))
groups = bottleneck_chs // group_width
first_dilation = first_dilation or dilation
cargs = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer, drop_block=drop_block)
self.conv1 = ConvBnAct(in_chs, bottleneck_chs, kernel_size=1, **cargs)
self.conv2 = ConvBnAct(
bottleneck_chs, bottleneck_chs, kernel_size=3, stride=stride, dilation=first_dilation,
groups=groups, **cargs)
if se_ratio:
se_channels = int(round(in_chs * se_ratio))
self.se = SEModule(bottleneck_chs, reduction_channels=se_channels)
else:
self.se = None
cargs['act_layer'] = None
self.conv3 = ConvBnAct(bottleneck_chs, out_chs, kernel_size=1, **cargs)
self.act3 = act_layer(inplace=True)
self.downsample = downsample
self.drop_path = drop_path
def zero_init_last_bn(self):
nn.init.zeros_(self.conv3.bn.weight)
def forward(self, x):
shortcut = x
x = self.conv1(x)
x = self.conv2(x)
if self.se is not None:
x = self.se(x)
x = self.conv3(x)
if self.drop_path is not None:
x = self.drop_path(x)
if self.downsample is not None:
shortcut = self.downsample(shortcut)
x += shortcut
x = self.act3(x)
return x
def downsample_conv(
in_chs, out_chs, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None):
norm_layer = norm_layer or nn.BatchNorm2d
kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size
first_dilation = (first_dilation or dilation) if kernel_size > 1 else 1
return ConvBnAct(
in_chs, out_chs, kernel_size, stride=stride, dilation=first_dilation, norm_layer=norm_layer, act_layer=None)
def downsample_avg(
in_chs, out_chs, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None):
""" AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment."""
norm_layer = norm_layer or nn.BatchNorm2d
avg_stride = stride if dilation == 1 else 1
pool = nn.Identity()
if stride > 1 or dilation > 1:
avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
return nn.Sequential(*[
pool, ConvBnAct(in_chs, out_chs, 1, stride=1, norm_layer=norm_layer, act_layer=None)])
class RegStage(nn.Module):
"""Stage (sequence of blocks w/ the same output shape)."""
def __init__(self, in_chs, out_chs, stride, depth, block_fn, bottle_ratio, group_width, se_ratio):
super(RegStage, self).__init__()
block_kwargs = {} # FIXME setup to pass various aa, norm, act layer common args
for i in range(depth):
block_stride = stride if i == 0 else 1
block_in_chs = in_chs if i == 0 else out_chs
if (block_in_chs != out_chs) or (block_stride != 1):
proj_block = downsample_conv(block_in_chs, out_chs, 1, stride)
else:
proj_block = None
name = "b{}".format(i + 1)
self.add_module(
name, block_fn(
block_in_chs, out_chs, block_stride, bottle_ratio, group_width, se_ratio,
downsample=proj_block, **block_kwargs)
)
def forward(self, x):
for block in self.children():
x = block(x)
return x
class ClassifierHead(nn.Module):
"""Head."""
def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0.):
super(ClassifierHead, self).__init__()
self.drop_rate = drop_rate
self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
if num_classes > 0:
self.fc = nn.Linear(in_chs, num_classes, bias=True)
else:
self.fc = nn.Identity()
def forward(self, x):
x = self.global_pool(x).flatten(1)
if self.drop_rate:
x = F.dropout(x, p=float(self.drop_rate), training=self.training)
x = self.fc(x)
return x
class RegNet(nn.Module):
"""RegNet model.
Paper: https://arxiv.org/abs/2003.13678
Original Impl: https://github.com/facebookresearch/pycls/blob/master/pycls/models/regnet.py
"""
def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0.,
zero_init_last_bn=True):
super().__init__()
# TODO add drop block, drop path, anti-aliasing, custom bn/act args
self.num_classes = num_classes
self.drop_rate = drop_rate
# Construct the stem
stem_width = cfg['stem_width']
self.stem = ConvBnAct(in_chans, stem_width, 3, stride=2)
# Construct the stages
block_fn = Bottleneck
prev_width = stem_width
stage_params = self._get_stage_params(cfg)
se_ratio = cfg['se_ratio']
for i, (d, w, s, br, gw) in enumerate(stage_params):
self.add_module(
"s{}".format(i + 1), RegStage(prev_width, w, s, d, block_fn, br, gw, se_ratio))
prev_width = w
# Construct the head
self.num_features = prev_width
self.head = ClassifierHead(
in_chs=prev_width, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, mean=0.0, std=0.01)
nn.init.zeros_(m.bias)
if zero_init_last_bn:
for m in self.modules():
if hasattr(m, 'zero_init_last_bn'):
m.zero_init_last_bn()
def _get_stage_params(self, cfg, stride=2):
# Generate RegNet ws per block
w_a, w_0, w_m, d = cfg['wa'], cfg['w0'], cfg['wm'], cfg['depth']
widths, num_stages, _, _ = generate_regnet(w_a, w_0, w_m, d)
# Convert to per stage format
stage_widths, stage_depths = np.unique(widths, return_counts=True)
# Use the same group width, bottleneck mult and stride for each stage
stage_groups = [cfg['group_w'] for _ in range(num_stages)]
stage_bottle_ratios = [cfg['bottle_ratio'] for _ in range(num_stages)]
stage_strides = [stride for _ in range(num_stages)]
# FIXME add dilation / output_stride support
# Adjust the compatibility of ws and gws
stage_widths, stage_groups = adjust_widths_groups_comp(stage_widths, stage_bottle_ratios, stage_groups)
stage_params = list(zip(stage_depths, stage_widths, stage_strides, stage_bottle_ratios, stage_groups))
return stage_params
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool='avg'):
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
def forward_features(self, x):
for block in list(self.children())[:-1]:
x = block(x)
return x
def forward(self, x):
for block in self.children():
x = block(x)
return x
def _regnet(variant, pretrained, **kwargs):
load_strict = True
model_class = RegNet
if kwargs.pop('features_only', False):
assert False, 'Not Implemented' # TODO
load_strict = False
kwargs.pop('num_classes', 0)
model_cfg = model_cfgs[variant]
default_cfg = default_cfgs[variant]
model = model_class(model_cfg, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(
model, default_cfg,
num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=load_strict)
return model
@register_model
def regnetx_002(pretrained=False, **kwargs):
"""RegNetX-200MF"""
return _regnet('x_002', pretrained, **kwargs)
@register_model
def regnetx_004(pretrained=False, **kwargs):
"""RegNetX-400MF"""
return _regnet('x_004', pretrained, **kwargs)
@register_model
def regnetx_006(pretrained=False, **kwargs):
"""RegNetX-600MF"""
return _regnet('x_006', pretrained, **kwargs)
@register_model
def regnetx_008(pretrained=False, **kwargs):
"""RegNetX-800MF"""
return _regnet('x_008', pretrained, **kwargs)
@register_model
def regnetx_016(pretrained=False, **kwargs):
"""RegNetX-1.6GF"""
return _regnet('x_016', pretrained, **kwargs)
@register_model
def regnetx_032(pretrained=False, **kwargs):
"""RegNetX-3.2GF"""
return _regnet('x_032', pretrained, **kwargs)
@register_model
def regnetx_040(pretrained=False, **kwargs):
"""RegNetX-4.0GF"""
return _regnet('x_040', pretrained, **kwargs)
@register_model
def regnetx_064(pretrained=False, **kwargs):
"""RegNetX-6.4GF"""
return _regnet('x_064', pretrained, **kwargs)
@register_model
def regnetx_080(pretrained=False, **kwargs):
"""RegNetX-8.0GF"""
return _regnet('x_080', pretrained, **kwargs)
@register_model
def regnetx_120(pretrained=False, **kwargs):
"""RegNetX-12GF"""
return _regnet('x_120', pretrained, **kwargs)
@register_model
def regnetx_160(pretrained=False, **kwargs):
"""RegNetX-16GF"""
return _regnet('x_160', pretrained, **kwargs)
@register_model
def regnetx_320(pretrained=False, **kwargs):
"""RegNetX-32GF"""
return _regnet('x_320', pretrained, **kwargs)
@register_model
def regnety_002(pretrained=False, **kwargs):
"""RegNetY-200MF"""
return _regnet('y_002', pretrained, **kwargs)
@register_model
def regnety_004(pretrained=False, **kwargs):
"""RegNetY-400MF"""
return _regnet('y_004', pretrained, **kwargs)
@register_model
def regnety_006(pretrained=False, **kwargs):
"""RegNetY-600MF"""
return _regnet('y_006', pretrained, **kwargs)
@register_model
def regnety_008(pretrained=False, **kwargs):
"""RegNetY-800MF"""
return _regnet('y_008', pretrained, **kwargs)
@register_model
def regnety_016(pretrained=False, **kwargs):
"""RegNetY-1.6GF"""
return _regnet('y_016', pretrained, **kwargs)
@register_model
def regnety_032(pretrained=False, **kwargs):
"""RegNetY-3.2GF"""
return _regnet('y_032', pretrained, **kwargs)
@register_model
def regnety_040(pretrained=False, **kwargs):
"""RegNetY-4.0GF"""
return _regnet('y_040', pretrained, **kwargs)
@register_model
def regnety_064(pretrained=False, **kwargs):
"""RegNetY-6.4GF"""
return _regnet('y_064', pretrained, **kwargs)
@register_model
def regnety_080(pretrained=False, **kwargs):
"""RegNetY-8.0GF"""
return _regnet('y_080', pretrained, **kwargs)
@register_model
def regnety_120(pretrained=False, **kwargs):
"""RegNetY-12GF"""
return _regnet('y_120', pretrained, **kwargs)
@register_model
def regnety_160(pretrained=False, **kwargs):
"""RegNetY-16GF"""
return _regnet('y_160', pretrained, **kwargs)
@register_model
def regnety_320(pretrained=False, **kwargs):
"""RegNetY-32GF"""
return _regnet('y_320', pretrained, **kwargs)

@ -38,11 +38,14 @@ default_cfgs = {
'resnest50d': _cfg(
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest50-528c19ca.pth'),
'resnest101e': _cfg(
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest101-22405ba7.pth', input_size=(3, 256, 256)),
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest101-22405ba7.pth',
input_size=(3, 256, 256), pool_size=(8, 8)),
'resnest200e': _cfg(
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest200-75117900.pth', input_size=(3, 320, 320)),
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest200-75117900.pth',
input_size=(3, 320, 320), pool_size=(10, 10)),
'resnest269e': _cfg(
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest269-0cc87c48.pth', input_size=(3, 416, 416)),
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest269-0cc87c48.pth',
input_size=(3, 416, 416), pool_size=(13, 13)),
'resnest50d_4s2x40d': _cfg(
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest50_fast_4s2x40d-41d14ed0.pth',
interpolation='bicubic'),

@ -25,7 +25,7 @@ __all__ = ['SelecSLS'] # model_registry will add each entrypoint fn to this
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (3, 3),
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (4, 4),
'crop_pct': 0.875, 'interpolation': 'bilinear',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem', 'classifier': 'fc',

@ -30,7 +30,7 @@ def _cfg(url='', **kwargs):
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bilinear',
'mean': (0, 0, 0), 'std': (1, 1, 1),
'first_conv': 'layer0.conv1', 'classifier': 'head.fc',
'first_conv': 'body.conv1', 'classifier': 'head.fc',
**kwargs
}
@ -43,13 +43,13 @@ default_cfgs = {
'tresnet_xl': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_xl_82_0-a2d51b00.pth'),
'tresnet_m_448': _cfg(
input_size=(3, 448, 448),
input_size=(3, 448, 448), pool_size=(14, 14),
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_m_448-bc359d10.pth'),
'tresnet_l_448': _cfg(
input_size=(3, 448, 448),
input_size=(3, 448, 448), pool_size=(14, 14),
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_448-940d0cd1.pth'),
'tresnet_xl_448': _cfg(
input_size=(3, 448, 448),
input_size=(3, 448, 448), pool_size=(14, 14),
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_xl_448-8c1815de.pth')
}

@ -35,6 +35,7 @@ default_cfgs = {
'xception': {
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/xception-43020ad28.pth',
'input_size': (3, 299, 299),
'pool_size': (10, 10),
'crop_pct': 0.8975,
'interpolation': 'bicubic',
'mean': (0.5, 0.5, 0.5),

Loading…
Cancel
Save