Add support for tflite mnasnet pretrained weights and included spnasnet pretrained weights of my own.

* tensorflow 'SAME' padding support added to GenMobileNet models for tflite pretrained weights
* folded batch norm support (made batch norm optional and enable conv bias) for tflite pretrained weights
* add url for spnasnet1_00 weights that I recently trained
* fix SE reduction size for semnasnet models
pull/1/head
Ross Wightman 5 years ago
parent afb357ff68
commit 4663fc2132

@ -0,0 +1,39 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class Conv2dSame(nn.Conv2d):
""" Tensorflow like 'SAME' convolution wrapper for 2D convolutions
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True):
super(Conv2dSame, self).__init__(
in_channels, out_channels, kernel_size, stride, 0, dilation,
groups, bias)
def forward(self, x):
ih, iw = x.size()[-2:]
kh, kw = self.weight.size()[-2:]
oh = math.ceil(ih / self.stride[0])
ow = math.ceil(iw / self.stride[1])
pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
if pad_h > 0 or pad_w > 0:
x = F.pad(x, [pad_w//2, pad_w - pad_w//2, pad_h//2, pad_h - pad_h//2])
return F.conv2d(x, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups)
# helper method
def sconv2d(in_chs, out_chs, kernel_size, **kwargs):
padding = kwargs.pop('padding', 0)
if isinstance(padding, str):
if padding.lower() == 'same':
return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
else:
# 'valid'
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=0, **kwargs)
else:
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)

@ -23,6 +23,7 @@ import torch.nn as nn
import torch.nn.functional as F
from models.helpers import load_pretrained
from models.adaptive_avgmax_pool import SelectAdaptivePool2d
from models.conv2d_same import sconv2d
from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
__all__ = ['GenMobileNet', 'mnasnet0_50', 'mnasnet0_75', 'mnasnet1_00', 'mnasnet1_40',
@ -45,10 +46,12 @@ default_cfgs = {
'mnasnet0_50': _cfg(url=''),
'mnasnet0_75': _cfg(url=''),
'mnasnet1_00': _cfg(url=''),
'tflite_mnasnet1_00': _cfg(url='', interpolation='bicubic'),
'mnasnet1_40': _cfg(url=''),
'semnasnet0_50': _cfg(url=''),
'semnasnet0_75': _cfg(url=''),
'semnasnet1_00': _cfg(url=''),
'tflite_semnasnet1_00': _cfg(url='', interpolation='bicubic'),
'semnasnet1_40': _cfg(url=''),
'mnasnet_small': _cfg(url=''),
'mobilenetv1_1_00': _cfg(url=''),
@ -56,7 +59,7 @@ default_cfgs = {
'chamnetv1_1_00': _cfg(url=''),
'chamnetv2_1_00': _cfg(url=''),
'fbnetc_1_00': _cfg(url=''),
'spnasnet1_00': _cfg(url=''),
'spnasnet1_00': _cfg(url='https://www.dropbox.com/s/iieopt18rytkgaa/spnasnet1_00-048bc3f4.pth?dl=1'),
}
_DEBUG = True
@ -184,11 +187,15 @@ def _decode_block_str(block_str):
return [deepcopy(block_args) for _ in range(num_repeat)]
def _get_padding(kernel_size, stride, dilation):
def _get_padding(kernel_size, stride, dilation=1):
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
return padding
def _padding_arg(default, padding_same=False):
return 'SAME' if padding_same else default
def _decode_arch_args(string_list):
block_args = []
for block_str in string_list:
@ -219,12 +226,15 @@ class _BlockBuilder:
"""
def __init__(self, depth_multiplier=1.0, depth_divisor=8, min_depth=None,
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT):
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
folded_bn=False, padding_same=False):
self.depth_multiplier = depth_multiplier
self.depth_divisor = depth_divisor
self.min_depth = min_depth
self.bn_momentum = bn_momentum
self.bn_eps = bn_eps
self.folded_bn = folded_bn
self.padding_same = padding_same
self.in_chs = None
def _round_channels(self, chs):
@ -236,6 +246,8 @@ class _BlockBuilder:
ba['out_chs'] = _round_channels(ba['out_chs'])
ba['bn_momentum'] = self.bn_momentum
ba['bn_eps'] = self.bn_eps
ba['folded_bn'] = self.folded_bn
ba['padding_same'] = self.padding_same
if _DEBUG:
print('args:', ba)
# could replace this with lambdas or functools binding if variety increases
@ -320,29 +332,37 @@ def _initialize_weight_default(m):
class DepthwiseSeparableConv(nn.Module):
def __init__(self, in_chs, out_chs, kernel_size,
stride=1, act_fn=F.relu, noskip=False, pw_act=False,
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT):
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
folded_bn=False, padding_same=False):
super(DepthwiseSeparableConv, self).__init__()
assert stride in [1, 2]
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
self.has_pw_act = pw_act # activation after point-wise conv
self.act_fn = act_fn
dw_padding = _padding_arg(kernel_size // 2, padding_same)
pw_padding = _padding_arg(0, padding_same)
self.conv_dw = nn.Conv2d(
self.conv_dw = sconv2d(
in_chs, in_chs, kernel_size,
stride=stride, padding=kernel_size // 2, groups=in_chs, bias=False)
self.bn1 = nn.BatchNorm2d(in_chs, momentum=bn_momentum, eps=bn_eps)
self.conv_pw = nn.Conv2d(in_chs, out_chs, 1, bias=False)
self.bn2 = nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
stride=stride, padding=dw_padding, groups=in_chs, bias=folded_bn)
self.bn1 = None if folded_bn else nn.BatchNorm2d(in_chs, momentum=bn_momentum, eps=bn_eps)
self.conv_pw = sconv2d(in_chs, out_chs, 1, padding=pw_padding, bias=folded_bn)
self.bn2 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
def forward(self, x):
residual = x
x = self.conv_dw(x)
x = self.bn1(x)
if self.bn1 is not None:
x = self.bn1(x)
x = self.act_fn(x)
x = self.conv_pw(x)
x = self.bn2(x)
if self.bn2 is not None:
x = self.bn2(x)
if self.has_pw_act:
x = self.act_fn(x)
if self.has_residual:
x += residual
return x
@ -351,24 +371,28 @@ class DepthwiseSeparableConv(nn.Module):
class CascadeConv3x3(nn.Sequential):
# FIXME lifted from maskrcnn_benchmark blocks, haven't used yet
def __init__(self, in_chs, out_chs, stride, act_fn=F.relu, noskip=False,
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT):
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
folded_bn=False, padding_same=False):
super(CascadeConv3x3, self).__init__()
assert stride in [1, 2]
self.has_residual = not noskip and (stride == 1 and in_chs == out_chs)
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
self.act_fn = act_fn
padding = _padding_arg(1, padding_same)
self.conv1 = nn.Conv2d(in_chs, in_chs, 3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(in_chs, momentum=bn_momentum, eps=bn_eps)
self.conv2 = nn.Conv2d(in_chs, out_chs, 3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
self.conv1 = sconv2d(in_chs, in_chs, 3, stride=stride, padding=padding, bias=folded_bn)
self.bn1 = None if folded_bn else nn.BatchNorm2d(in_chs, momentum=bn_momentum, eps=bn_eps)
self.conv2 = sconv2d(in_chs, out_chs, 3, stride=1, padding=padding, bias=folded_bn)
self.bn2 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
def forward(self, x):
residual = x
x = self.conv1(x)
x = self.bn1(x)
if self.bn1 is not None:
x = self.bn1(x)
x = self.act_fn(x)
x = self.conv2(x)
x = self.bn2(x)
if self.bn2 is not None:
x = self.bn2(x)
if self.has_residual:
x += residual
return x
@ -396,10 +420,10 @@ class ChannelShuffle(nn.Module):
class SqueezeExcite(nn.Module):
def __init__(self, in_chs, se_ratio=0.25, act_fn=F.relu):
def __init__(self, in_chs, reduce_chs=None, act_fn=F.relu):
super(SqueezeExcite, self).__init__()
self.act_fn = act_fn
reduced_chs = max(1, int(in_chs * se_ratio))
reduced_chs = reduce_chs or in_chs
self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
@ -419,41 +443,44 @@ class InvertedResidual(nn.Module):
def __init__(self, in_chs, out_chs, kernel_size,
stride=1, act_fn=F.relu, exp_ratio=1.0, noskip=False,
se_ratio=0., shuffle_type=None, pw_group=1,
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT):
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
folded_bn=False, padding_same=False):
super(InvertedResidual, self).__init__()
mid_chs = int(in_chs * exp_ratio)
self.has_se = se_ratio is not None and se_ratio > 0.
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
self.act_fn = act_fn
dw_padding = _padding_arg(kernel_size // 2, padding_same)
pw_padding = _padding_arg(0, padding_same)
# Point-wise expansion
self.conv_pw = nn.Conv2d(in_chs, mid_chs, 1, groups=pw_group, bias=False)
self.bn1 = nn.BatchNorm2d(mid_chs, momentum=bn_momentum, eps=bn_eps)
self.conv_pw = sconv2d(in_chs, mid_chs, 1, padding=pw_padding, groups=pw_group, bias=folded_bn)
self.bn1 = None if folded_bn else nn.BatchNorm2d(mid_chs, momentum=bn_momentum, eps=bn_eps)
self.shuffle_type = shuffle_type
if shuffle_type is not None:
self.shuffle = ChannelShuffle(pw_group)
# Depth-wise convolution
self.conv_dw = nn.Conv2d(
mid_chs, mid_chs, kernel_size, padding=kernel_size // 2,
stride=stride, groups=mid_chs, bias=False)
self.bn2 = nn.BatchNorm2d(mid_chs, momentum=bn_momentum, eps=bn_eps)
self.conv_dw = sconv2d(
mid_chs, mid_chs, kernel_size, padding=dw_padding, stride=stride, groups=mid_chs, bias=folded_bn)
self.bn2 = None if folded_bn else nn.BatchNorm2d(mid_chs, momentum=bn_momentum, eps=bn_eps)
# Squeeze-and-excitation
if self.has_se:
self.se = SqueezeExcite(mid_chs, se_ratio)
self.se = SqueezeExcite(mid_chs, reduce_chs=max(1, int(in_chs * se_ratio)))
# Point-wise linear projection
self.conv_pwl = nn.Conv2d(mid_chs, out_chs, 1, groups=pw_group, bias=False)
self.bn3 = nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
self.conv_pwl = sconv2d(mid_chs, out_chs, 1, padding=pw_padding, groups=pw_group, bias=folded_bn)
self.bn3 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
def forward(self, x):
residual = x
# Point-wise expansion
x = self.conv_pw(x)
x = self.bn1(x)
if self.bn1 is not None:
x = self.bn1(x)
x = self.act_fn(x)
# FIXME haven't tried this yet
@ -463,7 +490,8 @@ class InvertedResidual(nn.Module):
# Depth-wise convolution
x = self.conv_dw(x)
x = self.bn2(x)
if self.bn2 is not None:
x = self.bn2(x)
x = self.act_fn(x)
# Squeeze-and-excitation
@ -472,7 +500,8 @@ class InvertedResidual(nn.Module):
# Point-wise linear projection
x = self.conv_pwl(x)
x = self.bn3(x)
if self.bn3 is not None:
x = self.bn3(x)
if self.has_residual:
x += residual
@ -498,7 +527,7 @@ class GenMobileNet(nn.Module):
depth_multiplier=1.0, depth_divisor=8, min_depth=None,
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
drop_rate=0., act_fn=F.relu, global_pool='avg', skip_head_conv=False,
weight_init='goog'):
weight_init='goog', folded_bn=False, padding_same=False):
super(GenMobileNet, self).__init__()
self.num_classes = num_classes
self.depth_multiplier = depth_multiplier
@ -507,13 +536,15 @@ class GenMobileNet(nn.Module):
self.num_features = num_features
stem_size = _round_channels(stem_size, depth_multiplier, depth_divisor, min_depth)
self.conv_stem = nn.Conv2d(in_chans, stem_size, 3, padding=1, stride=2, bias=False)
self.bn1 = nn.BatchNorm2d(stem_size, momentum=bn_momentum, eps=bn_eps)
self.conv_stem = sconv2d(
in_chans, stem_size, 3,
padding=_padding_arg(1, padding_same), stride=2, bias=folded_bn)
self.bn1 = None if folded_bn else nn.BatchNorm2d(stem_size, momentum=bn_momentum, eps=bn_eps)
in_chs = stem_size
builder = _BlockBuilder(
depth_multiplier, depth_divisor, min_depth,
bn_momentum, bn_eps)
bn_momentum, bn_eps, folded_bn, padding_same)
self.blocks = nn.Sequential(*builder(in_chs, block_args))
in_chs = builder.in_chs
@ -521,8 +552,10 @@ class GenMobileNet(nn.Module):
self.conv_head = None
assert in_chs == self.num_features
else:
self.conv_head = nn.Conv2d(in_chs, self.num_features, 1, padding=0, stride=1, bias=False)
self.bn2 = nn.BatchNorm2d(self.num_features, momentum=bn_momentum, eps=bn_eps)
self.conv_head = sconv2d(
in_chs, self.num_features, 1,
padding=_padding_arg(0, padding_same), bias=folded_bn)
self.bn2 = None if folded_bn else nn.BatchNorm2d(self.num_features, momentum=bn_momentum, eps=bn_eps)
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.classifier = nn.Linear(self.num_features, self.num_classes)
@ -548,12 +581,14 @@ class GenMobileNet(nn.Module):
def forward_features(self, x, pool=True):
x = self.conv_stem(x)
x = self.bn1(x)
if self.bn1 is not None:
x = self.bn1(x)
x = self.act_fn(x)
x = self.blocks(x)
if self.conv_head is not None:
x = self.conv_head(x)
x = self.bn2(x)
if self.bn2 is not None:
x = self.bn2(x)
x = self.act_fn(x)
if pool:
x = self.global_pool(x)
@ -909,6 +944,19 @@ def mnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs):
return model
def tflite_mnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs):
""" MNASNet B1, depth multiplier of 1.0. """
default_cfg = default_cfgs['tflite_mnasnet1_00']
# these two args are for compat with tflite pretrained weights
kwargs['folded_bn'] = True
kwargs['padding_same'] = True
model = _gen_mnasnet_b1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
def mnasnet1_40(num_classes, in_chans=3, pretrained=False, **kwargs):
""" MNASNet B1, depth multiplier of 1.4 """
default_cfg = default_cfgs['mnasnet1_40']
@ -949,6 +997,19 @@ def semnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs):
return model
def tflite_semnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs):
""" MNASNet A1, depth multiplier of 1.0. """
default_cfg = default_cfgs['tflite_semnasnet1_00']
# these two args are for compat with tflite pretrained weights
kwargs['folded_bn'] = True
kwargs['padding_same'] = True
model = _gen_mnasnet_a1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
def semnasnet1_40(num_classes, in_chans=3, pretrained=False, **kwargs):
""" MNASNet A1 (w/ SE), depth multiplier of 1.4. """
default_cfg = default_cfgs['semnasnet1_40']

@ -9,8 +9,8 @@ from models.senet import seresnet18, seresnet34, seresnet50, seresnet101, seresn
from models.xception import xception
from models.pnasnet import pnasnet5large
from models.genmobilenet import \
mnasnet0_50, mnasnet0_75, mnasnet1_00, mnasnet1_40,\
semnasnet0_50, semnasnet0_75, semnasnet1_00, semnasnet1_40, mnasnet_small,\
mnasnet0_50, mnasnet0_75, mnasnet1_00, mnasnet1_40, tflite_mnasnet1_00,\
semnasnet0_50, semnasnet0_75, semnasnet1_00, semnasnet1_40, tflite_semnasnet1_00, mnasnet_small,\
mobilenetv1_1_00, mobilenetv2_1_00, fbnetc_1_00, chamnetv1_1_00, chamnetv2_1_00,\
spnasnet1_00

Loading…
Cancel
Save