From 4663fc2132589296688503fd4606c0ccf148f81b Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Sun, 28 Apr 2019 17:45:07 -0700
Subject: [PATCH] Add support for tflite mnasnet pretrained weights and
 include spnasnet pretrained weights of my own.

* tensorflow 'SAME' padding support added to GenMobileNet models for tflite pretrained weights
* folded batch norm support (made batch norm optional and enabled conv bias) for tflite pretrained weights
* add url for spnasnet1_00 weights that I recently trained
* fix SE reduction size for semnasnet models
---
 models/conv2d_same.py   |  39 +++++++++++
 models/genmobilenet.py  | 145 ++++++++++++++++++++++++++++------------
 models/model_factory.py |   4 +-
 3 files changed, 144 insertions(+), 44 deletions(-)
 create mode 100644 models/conv2d_same.py

diff --git a/models/conv2d_same.py b/models/conv2d_same.py
new file mode 100644
index 00000000..6d0b9e09
--- /dev/null
+++ b/models/conv2d_same.py
@@ -0,0 +1,39 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+
+
+class Conv2dSame(nn.Conv2d):
+    """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions
+    """
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True):
+        super(Conv2dSame, self).__init__(
+            in_channels, out_channels, kernel_size, stride, 0, dilation,
+            groups, bias)
+
+    def forward(self, x):
+        ih, iw = x.size()[-2:]
+        kh, kw = self.weight.size()[-2:]
+        oh = math.ceil(ih / self.stride[0])
+        ow = math.ceil(iw / self.stride[1])
+        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
+        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
+        if pad_h > 0 or pad_w > 0:
+            x = F.pad(x, [pad_w//2, pad_w - pad_w//2, pad_h//2, pad_h - pad_h//2])
+        return F.conv2d(x, self.weight, self.bias, self.stride,
+                        self.padding, self.dilation, self.groups)
+
+
+# helper method
+def sconv2d(in_chs, out_chs, kernel_size, **kwargs):
+    padding = kwargs.pop('padding', 0)
+    if isinstance(padding, str):
+        if padding.lower() == 'same':
+            return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
+        else:
+            # 'valid'
+            return nn.Conv2d(in_chs, out_chs, kernel_size, padding=0, **kwargs)
+    else:
+        return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
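[Note: the snippet below is an illustrative sanity check, not part of the patch. It shows the practical effect of Conv2dSame: the output spatial size is always ceil(input / stride), and when the required padding total is odd the extra row/column goes on the bottom/right, matching TensorFlow's 'SAME' convention — which is what converted tflite weights expect.]

    import torch
    from models.conv2d_same import sconv2d

    x = torch.randn(1, 3, 224, 224)
    conv = sconv2d(3, 32, 3, stride=2, padding='SAME', bias=False)
    print(conv(x).shape)  # torch.Size([1, 32, 112, 112]), i.e. ceil(224 / 2)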
diff --git a/models/genmobilenet.py b/models/genmobilenet.py
index 842ba33f..53babe7a 100644
--- a/models/genmobilenet.py
+++ b/models/genmobilenet.py
@@ -23,6 +23,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from models.helpers import load_pretrained
 from models.adaptive_avgmax_pool import SelectAdaptivePool2d
+from models.conv2d_same import sconv2d
 from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 
 __all__ = ['GenMobileNet', 'mnasnet0_50', 'mnasnet0_75', 'mnasnet1_00', 'mnasnet1_40',
@@ -45,10 +46,12 @@ default_cfgs = {
     'mnasnet0_50': _cfg(url=''),
     'mnasnet0_75': _cfg(url=''),
     'mnasnet1_00': _cfg(url=''),
+    'tflite_mnasnet1_00': _cfg(url='', interpolation='bicubic'),
     'mnasnet1_40': _cfg(url=''),
     'semnasnet0_50': _cfg(url=''),
     'semnasnet0_75': _cfg(url=''),
     'semnasnet1_00': _cfg(url=''),
+    'tflite_semnasnet1_00': _cfg(url='', interpolation='bicubic'),
     'semnasnet1_40': _cfg(url=''),
     'mnasnet_small': _cfg(url=''),
     'mobilenetv1_1_00': _cfg(url=''),
@@ -56,7 +59,7 @@ default_cfgs = {
     'chamnetv1_1_00': _cfg(url=''),
     'chamnetv2_1_00': _cfg(url=''),
     'fbnetc_1_00': _cfg(url=''),
-    'spnasnet1_00': _cfg(url=''),
+    'spnasnet1_00': _cfg(url='https://www.dropbox.com/s/iieopt18rytkgaa/spnasnet1_00-048bc3f4.pth?dl=1'),
 }
 
 _DEBUG = True
@@ -184,11 +187,15 @@ def _decode_block_str(block_str):
     return [deepcopy(block_args) for _ in range(num_repeat)]
 
 
-def _get_padding(kernel_size, stride, dilation):
+def _get_padding(kernel_size, stride, dilation=1):
     padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
     return padding
 
 
+def _padding_arg(default, padding_same=False):
+    return 'SAME' if padding_same else default
+
+
 def _decode_arch_args(string_list):
     block_args = []
     for block_str in string_list:
@@ -219,12 +226,15 @@ class _BlockBuilder:
     """
     def __init__(self, depth_multiplier=1.0, depth_divisor=8, min_depth=None,
-                 bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT):
+                 bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
+                 folded_bn=False, padding_same=False):
         self.depth_multiplier = depth_multiplier
         self.depth_divisor = depth_divisor
         self.min_depth = min_depth
         self.bn_momentum = bn_momentum
         self.bn_eps = bn_eps
+        self.folded_bn = folded_bn
+        self.padding_same = padding_same
         self.in_chs = None
 
     def _round_channels(self, chs):
@@ -236,6 +246,8 @@ class _BlockBuilder:
         ba['out_chs'] = _round_channels(ba['out_chs'])
         ba['bn_momentum'] = self.bn_momentum
         ba['bn_eps'] = self.bn_eps
+        ba['folded_bn'] = self.folded_bn
+        ba['padding_same'] = self.padding_same
         if _DEBUG:
             print('args:', ba)
         # could replace this with lambdas or functools binding if variety increases
@@ -320,29 +332,37 @@ def _initialize_weight_default(m):
 class DepthwiseSeparableConv(nn.Module):
     def __init__(self, in_chs, out_chs, kernel_size,
                  stride=1, act_fn=F.relu, noskip=False, pw_act=False,
-                 bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT):
+                 bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
+                 folded_bn=False, padding_same=False):
         super(DepthwiseSeparableConv, self).__init__()
         assert stride in [1, 2]
         self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
         self.has_pw_act = pw_act  # activation after point-wise conv
         self.act_fn = act_fn
+        dw_padding = _padding_arg(kernel_size // 2, padding_same)
+        pw_padding = _padding_arg(0, padding_same)
 
-        self.conv_dw = nn.Conv2d(
+        self.conv_dw = sconv2d(
             in_chs, in_chs, kernel_size,
-            stride=stride, padding=kernel_size // 2, groups=in_chs, bias=False)
-        self.bn1 = nn.BatchNorm2d(in_chs, momentum=bn_momentum, eps=bn_eps)
-        self.conv_pw = nn.Conv2d(in_chs, out_chs, 1, bias=False)
-        self.bn2 = nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
+            stride=stride, padding=dw_padding, groups=in_chs, bias=folded_bn)
+        self.bn1 = None if folded_bn else nn.BatchNorm2d(in_chs, momentum=bn_momentum, eps=bn_eps)
+        self.conv_pw = sconv2d(in_chs, out_chs, 1, padding=pw_padding, bias=folded_bn)
+        self.bn2 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
 
     def forward(self, x):
         residual = x
+
         x = self.conv_dw(x)
-        x = self.bn1(x)
+        if self.bn1 is not None:
+            x = self.bn1(x)
         x = self.act_fn(x)
+
         x = self.conv_pw(x)
-        x = self.bn2(x)
+        if self.bn2 is not None:
+            x = self.bn2(x)
         if self.has_pw_act:
             x = self.act_fn(x)
+
         if self.has_residual:
             x += residual
         return x
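[Note: folded_bn=True assumes batch norm statistics were already folded into the preceding conv's weight and bias when the TF checkpoint was converted, which is why the modules above skip creating BatchNorm2d and enable the conv bias instead. A minimal sketch of the standard folding math — fold_bn_into_conv is a hypothetical helper for illustration, not part of this patch or its conversion tooling:]

    import torch

    def fold_bn_into_conv(conv_w, gamma, beta, running_mean, running_var, eps=1e-3):
        # BN(conv(x)) == conv'(x) with rescaled weight plus a bias term:
        # scale = gamma / sqrt(var + eps); w' = w * scale; b' = beta - mean * scale
        scale = gamma / torch.sqrt(running_var + eps)
        return conv_w * scale.reshape(-1, 1, 1, 1), beta - running_mean * scale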
@@ -351,24 +371,28 @@ class DepthwiseSeparableConv(nn.Module):
 class CascadeConv3x3(nn.Sequential):
     # FIXME lifted from maskrcnn_benchmark blocks, haven't used yet
     def __init__(self, in_chs, out_chs, stride, act_fn=F.relu, noskip=False,
-                 bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT):
+                 bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
+                 folded_bn=False, padding_same=False):
         super(CascadeConv3x3, self).__init__()
         assert stride in [1, 2]
-        self.has_residual = not noskip and (stride == 1 and in_chs == out_chs)
+        self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
         self.act_fn = act_fn
+        padding = _padding_arg(1, padding_same)
 
-        self.conv1 = nn.Conv2d(in_chs, in_chs, 3, stride=stride, padding=1, bias=False)
-        self.bn1 = nn.BatchNorm2d(in_chs, momentum=bn_momentum, eps=bn_eps)
-        self.conv2 = nn.Conv2d(in_chs, out_chs, 3, stride=1, padding=1, bias=False)
-        self.bn2 = nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
+        self.conv1 = sconv2d(in_chs, in_chs, 3, stride=stride, padding=padding, bias=folded_bn)
+        self.bn1 = None if folded_bn else nn.BatchNorm2d(in_chs, momentum=bn_momentum, eps=bn_eps)
+        self.conv2 = sconv2d(in_chs, out_chs, 3, stride=1, padding=padding, bias=folded_bn)
+        self.bn2 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
 
     def forward(self, x):
         residual = x
         x = self.conv1(x)
-        x = self.bn1(x)
+        if self.bn1 is not None:
+            x = self.bn1(x)
         x = self.act_fn(x)
         x = self.conv2(x)
-        x = self.bn2(x)
+        if self.bn2 is not None:
+            x = self.bn2(x)
         if self.has_residual:
             x += residual
         return x
@@ -396,10 +420,10 @@ class ChannelShuffle(nn.Module):
 
 
 class SqueezeExcite(nn.Module):
-    def __init__(self, in_chs, se_ratio=0.25, act_fn=F.relu):
+    def __init__(self, in_chs, reduce_chs=None, act_fn=F.relu):
         super(SqueezeExcite, self).__init__()
         self.act_fn = act_fn
-        reduced_chs = max(1, int(in_chs * se_ratio))
+        reduced_chs = reduce_chs or in_chs
         self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
         self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
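[Note: the SqueezeExcite change above is the "fix SE reduction size" item from the commit message: the squeeze width is now supplied by the caller, computed from the block's input channels rather than from the expanded mid channels the SE module actually operates on. Illustrative numbers only, assuming a hypothetical block config:]

    in_chs, exp_ratio, se_ratio = 24, 6, 0.25
    mid_chs = int(in_chs * exp_ratio)          # 144: width the SE module sees
    reduced = max(1, int(in_chs * se_ratio))   # 6: squeeze width per this patch
    # the previous code reduced from the SE input width instead: int(144 * 0.25) == 36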
@@ -419,41 +443,44 @@ class InvertedResidual(nn.Module):
     def __init__(self, in_chs, out_chs, kernel_size,
                  stride=1, act_fn=F.relu, exp_ratio=1.0, noskip=False,
                  se_ratio=0., shuffle_type=None, pw_group=1,
-                 bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT):
+                 bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
+                 folded_bn=False, padding_same=False):
         super(InvertedResidual, self).__init__()
         mid_chs = int(in_chs * exp_ratio)
         self.has_se = se_ratio is not None and se_ratio > 0.
         self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
         self.act_fn = act_fn
+        dw_padding = _padding_arg(kernel_size // 2, padding_same)
+        pw_padding = _padding_arg(0, padding_same)
 
         # Point-wise expansion
-        self.conv_pw = nn.Conv2d(in_chs, mid_chs, 1, groups=pw_group, bias=False)
-        self.bn1 = nn.BatchNorm2d(mid_chs, momentum=bn_momentum, eps=bn_eps)
+        self.conv_pw = sconv2d(in_chs, mid_chs, 1, padding=pw_padding, groups=pw_group, bias=folded_bn)
+        self.bn1 = None if folded_bn else nn.BatchNorm2d(mid_chs, momentum=bn_momentum, eps=bn_eps)
 
         self.shuffle_type = shuffle_type
         if shuffle_type is not None:
             self.shuffle = ChannelShuffle(pw_group)
 
         # Depth-wise convolution
-        self.conv_dw = nn.Conv2d(
-            mid_chs, mid_chs, kernel_size, padding=kernel_size // 2,
-            stride=stride, groups=mid_chs, bias=False)
-        self.bn2 = nn.BatchNorm2d(mid_chs, momentum=bn_momentum, eps=bn_eps)
+        self.conv_dw = sconv2d(
+            mid_chs, mid_chs, kernel_size, padding=dw_padding, stride=stride, groups=mid_chs, bias=folded_bn)
+        self.bn2 = None if folded_bn else nn.BatchNorm2d(mid_chs, momentum=bn_momentum, eps=bn_eps)
 
         # Squeeze-and-excitation
         if self.has_se:
-            self.se = SqueezeExcite(mid_chs, se_ratio)
+            self.se = SqueezeExcite(mid_chs, reduce_chs=max(1, int(in_chs * se_ratio)))
 
         # Point-wise linear projection
-        self.conv_pwl = nn.Conv2d(mid_chs, out_chs, 1, groups=pw_group, bias=False)
-        self.bn3 = nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
+        self.conv_pwl = sconv2d(mid_chs, out_chs, 1, padding=pw_padding, groups=pw_group, bias=folded_bn)
+        self.bn3 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
 
     def forward(self, x):
         residual = x
 
         # Point-wise expansion
         x = self.conv_pw(x)
-        x = self.bn1(x)
+        if self.bn1 is not None:
+            x = self.bn1(x)
         x = self.act_fn(x)
 
         # FIXME haven't tried this yet
@@ -463,7 +490,8 @@ class InvertedResidual(nn.Module):
 
         # Depth-wise convolution
         x = self.conv_dw(x)
-        x = self.bn2(x)
+        if self.bn2 is not None:
+            x = self.bn2(x)
         x = self.act_fn(x)
 
         # Squeeze-and-excitation
@@ -472,7 +500,8 @@ class InvertedResidual(nn.Module):
 
         # Point-wise linear projection
         x = self.conv_pwl(x)
-        x = self.bn3(x)
+        if self.bn3 is not None:
+            x = self.bn3(x)
 
         if self.has_residual:
             x += residual
@@ -498,7 +527,7 @@ class GenMobileNet(nn.Module):
                  depth_multiplier=1.0, depth_divisor=8, min_depth=None,
                  bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
                  drop_rate=0., act_fn=F.relu, global_pool='avg', skip_head_conv=False,
-                 weight_init='goog'):
+                 weight_init='goog', folded_bn=False, padding_same=False):
         super(GenMobileNet, self).__init__()
         self.num_classes = num_classes
         self.depth_multiplier = depth_multiplier
@@ -507,13 +536,15 @@ class GenMobileNet(nn.Module):
         self.num_features = num_features
 
         stem_size = _round_channels(stem_size, depth_multiplier, depth_divisor, min_depth)
-        self.conv_stem = nn.Conv2d(in_chans, stem_size, 3, padding=1, stride=2, bias=False)
-        self.bn1 = nn.BatchNorm2d(stem_size, momentum=bn_momentum, eps=bn_eps)
+        self.conv_stem = sconv2d(
+            in_chans, stem_size, 3,
+            padding=_padding_arg(1, padding_same), stride=2, bias=folded_bn)
+        self.bn1 = None if folded_bn else nn.BatchNorm2d(stem_size, momentum=bn_momentum, eps=bn_eps)
         in_chs = stem_size
 
         builder = _BlockBuilder(
             depth_multiplier, depth_divisor, min_depth,
-            bn_momentum, bn_eps)
+            bn_momentum, bn_eps, folded_bn, padding_same)
         self.blocks = nn.Sequential(*builder(in_chs, block_args))
         in_chs = builder.in_chs
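[Note: as the hunks above show, the two new constructor flags flow from GenMobileNet through _BlockBuilder into every block's conv/bn construction. A minimal sketch of what the _padding_arg helper resolves to, for use inside this module (it is a private helper):]

    # padding_same=False keeps the numeric PyTorch-style padding;
    # padding_same=True substitutes the string 'SAME', which sconv2d
    # dispatches to Conv2dSame instead of nn.Conv2d.
    assert _padding_arg(1, padding_same=False) == 1
    assert _padding_arg(1, padding_same=True) == 'SAME'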
@@ -521,8 +552,10 @@ class GenMobileNet(nn.Module):
             self.conv_head = None
             assert in_chs == self.num_features
         else:
-            self.conv_head = nn.Conv2d(in_chs, self.num_features, 1, padding=0, stride=1, bias=False)
-            self.bn2 = nn.BatchNorm2d(self.num_features, momentum=bn_momentum, eps=bn_eps)
+            self.conv_head = sconv2d(
+                in_chs, self.num_features, 1,
+                padding=_padding_arg(0, padding_same), bias=folded_bn)
+            self.bn2 = None if folded_bn else nn.BatchNorm2d(self.num_features, momentum=bn_momentum, eps=bn_eps)
 
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.classifier = nn.Linear(self.num_features, self.num_classes)
@@ -548,12 +581,14 @@ def forward_features(self, x, pool=True):
         x = self.conv_stem(x)
-        x = self.bn1(x)
+        if self.bn1 is not None:
+            x = self.bn1(x)
         x = self.act_fn(x)
         x = self.blocks(x)
         if self.conv_head is not None:
             x = self.conv_head(x)
-            x = self.bn2(x)
+            if self.bn2 is not None:
+                x = self.bn2(x)
             x = self.act_fn(x)
         if pool:
             x = self.global_pool(x)
@@ -909,6 +944,19 @@ def mnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs):
     return model
 
 
+def tflite_mnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs):
+    """ MNASNet B1, depth multiplier of 1.0. """
+    default_cfg = default_cfgs['tflite_mnasnet1_00']
+    # these two args are for compat with tflite pretrained weights
+    kwargs['folded_bn'] = True
+    kwargs['padding_same'] = True
+    model = _gen_mnasnet_b1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
 def mnasnet1_40(num_classes, in_chans=3, pretrained=False, **kwargs):
     """ MNASNet B1, depth multiplier of 1.4. """
     default_cfg = default_cfgs['mnasnet1_40']
@@ -949,6 +997,19 @@ def semnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs):
     return model
 
 
+def tflite_semnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs):
+    """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """
+    default_cfg = default_cfgs['tflite_semnasnet1_00']
+    # these two args are for compat with tflite pretrained weights
+    kwargs['folded_bn'] = True
+    kwargs['padding_same'] = True
+    model = _gen_mnasnet_a1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
 def semnasnet1_40(num_classes, in_chans=3, pretrained=False, **kwargs):
     """ MNASNet A1 (w/ SE), depth multiplier of 1.4. """
     default_cfg = default_cfgs['semnasnet1_40']
diff --git a/models/model_factory.py b/models/model_factory.py
index 64c578ac..902804ce 100644
--- a/models/model_factory.py
+++ b/models/model_factory.py
@@ -9,8 +9,8 @@ from models.senet import seresnet18, seresnet34, seresnet50, seresnet101, seresn
 from models.xception import xception
 from models.pnasnet import pnasnet5large
 from models.genmobilenet import \
-    mnasnet0_50, mnasnet0_75, mnasnet1_00, mnasnet1_40,\
-    semnasnet0_50, semnasnet0_75, semnasnet1_00, semnasnet1_40, mnasnet_small,\
+    mnasnet0_50, mnasnet0_75, mnasnet1_00, mnasnet1_40, tflite_mnasnet1_00,\
+    semnasnet0_50, semnasnet0_75, semnasnet1_00, semnasnet1_40, tflite_semnasnet1_00, mnasnet_small,\
     mobilenetv1_1_00, mobilenetv2_1_00, fbnetc_1_00, chamnetv1_1_00, chamnetv2_1_00,\
     spnasnet1_00
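[Note: a usage sketch for the new entrypoints, illustrative only. pretrained=True for the tflite variants will only succeed once their default_cfg urls above are populated with converted weights.]

    from models.genmobilenet import tflite_mnasnet1_00

    # folded_bn=True and padding_same=True are forced inside the entrypoint,
    # so the module graph matches a checkpoint converted from tflite.
    model = tflite_mnasnet1_00(num_classes=1000, in_chans=3, pretrained=False)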