Morph mnasnet impl into a generic mobilenet that covers Mnasnet, MobileNetV1/V2, ChamNet, FBNet, and related
* add an alternate RMSprop opt that applies eps like TF * add bn params for passing through alternates and changing defaults to TF stylepull/1/head
parent
e9c7961efc
commit
bc264269c9
@ -0,0 +1,957 @@
|
|||||||
|
""" Generic MobileNet
|
||||||
|
|
||||||
|
A generic MobileNet class with building blocks to support a variety of models:
|
||||||
|
* MNasNet B1, A1 (SE), Small
|
||||||
|
* MobileNetV2
|
||||||
|
* FBNet-C (TODO A & B)
|
||||||
|
* ChamNet (TODO still guessing at architecture definition)
|
||||||
|
* ShuffleNetV2 (TODO add IR shuffle block)
|
||||||
|
* And likely more...
|
||||||
|
|
||||||
|
TODO not all combinations and variations have been tested. Currently working on training hyper-params...
|
||||||
|
|
||||||
|
Hacked together by Ross Wightman
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
import re
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from models.helpers import load_pretrained
|
||||||
|
from models.adaptive_avgmax_pool import SelectAdaptivePool2d
|
||||||
|
from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
|
|
||||||
|
__all__ = ['GenMobileNet', 'mnasnet0_50', 'mnasnet0_75', 'mnasnet1_00', 'mnasnet1_40',
|
||||||
|
'semnasnet0_50', 'semnasnet0_75', 'semnasnet1_00', 'semnasnet1_40',
|
||||||
|
'mnasnet_small']
|
||||||
|
|
||||||
|
|
||||||
|
def _cfg(url='', **kwargs):
|
||||||
|
return {
|
||||||
|
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
||||||
|
'crop_pct': 0.875, 'interpolation': 'bilinear',
|
||||||
|
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||||
|
'first_conv': 'conv_stem', 'classifier': 'classifier',
|
||||||
|
**kwargs
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
default_cfgs = {
|
||||||
|
'mnasnet0_50': _cfg(url=''),
|
||||||
|
'mnasnet0_75': _cfg(url=''),
|
||||||
|
'mnasnet1_00': _cfg(url=''),
|
||||||
|
'mnasnet1_40': _cfg(url=''),
|
||||||
|
'semnasnet0_50': _cfg(url=''),
|
||||||
|
'semnasnet0_75': _cfg(url=''),
|
||||||
|
'semnasnet1_00': _cfg(url=''),
|
||||||
|
'semnasnet1_40': _cfg(url=''),
|
||||||
|
'mnasnet_small': _cfg(url=''),
|
||||||
|
'mobilenetv1_1_00': _cfg(url=''),
|
||||||
|
'mobilenetv2_1_00': _cfg(url=''),
|
||||||
|
'chamnetv1_1_00': _cfg(url=''),
|
||||||
|
'chamnetv2_1_00': _cfg(url=''),
|
||||||
|
'fbnetc_1_00': _cfg(url=''),
|
||||||
|
}
|
||||||
|
|
||||||
|
_DEBUG = True
|
||||||
|
|
||||||
|
# default args for PyTorch BN impl
|
||||||
|
_BN_MOMENTUM_PT_DEFAULT = 0.1
|
||||||
|
_BN_EPS_PT_DEFAULT = 1e-5
|
||||||
|
|
||||||
|
# defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
|
||||||
|
# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)
|
||||||
|
_BN_MOMENTUM_TF_DEFAULT = 1 - 0.99 # NOTE this varies, .99 or .9997 depending on ref
|
||||||
|
_BN_EPS_TF_DEFAULT = 1e-3
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_bn_params(kwargs):
|
||||||
|
# NOTE kwargs passed as dict intentionally
|
||||||
|
bn_momentum_default = _BN_MOMENTUM_PT_DEFAULT
|
||||||
|
bn_eps_default = _BN_EPS_PT_DEFAULT
|
||||||
|
bn_tf = kwargs.pop('bn_tf', False)
|
||||||
|
if bn_tf:
|
||||||
|
bn_momentum_default = _BN_MOMENTUM_TF_DEFAULT
|
||||||
|
bn_eps_default = _BN_EPS_TF_DEFAULT
|
||||||
|
bn_momentum = kwargs.pop('bn_momentum', None)
|
||||||
|
bn_eps = kwargs.pop('bn_eps', None)
|
||||||
|
if bn_momentum is None:
|
||||||
|
bn_momentum = bn_momentum_default
|
||||||
|
if bn_eps is None:
|
||||||
|
bn_eps = bn_eps_default
|
||||||
|
return bn_momentum, bn_eps
|
||||||
|
|
||||||
|
|
||||||
|
def _round_channels(channels, depth_multiplier=1.0, depth_divisor=8, min_depth=None):
|
||||||
|
"""Round number of filters based on depth multiplier."""
|
||||||
|
if not depth_multiplier:
|
||||||
|
return channels
|
||||||
|
|
||||||
|
channels *= depth_multiplier
|
||||||
|
min_depth = min_depth or depth_divisor
|
||||||
|
new_channels = max(
|
||||||
|
int(channels + depth_divisor / 2) // depth_divisor * depth_divisor,
|
||||||
|
min_depth)
|
||||||
|
# Make sure that round down does not go down by more than 10%.
|
||||||
|
if new_channels < 0.9 * channels:
|
||||||
|
new_channels += depth_divisor
|
||||||
|
return new_channels
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_block_str(block_str):
|
||||||
|
""" Decode block definition string
|
||||||
|
|
||||||
|
Gets a list of block arg (dicts) through a string notation of arguments.
|
||||||
|
E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip
|
||||||
|
|
||||||
|
All args can exist in any order with the exception of the leading string which
|
||||||
|
is assumed to indicate the block type.
|
||||||
|
|
||||||
|
leading string - block type (
|
||||||
|
ir = InvertedResidual, ds = DepthwiseSep, dsa = DeptwhiseSep with pw act,
|
||||||
|
ca = Cascade3x3, and possibly more)
|
||||||
|
r - number of repeat blocks,
|
||||||
|
k - kernel size,
|
||||||
|
s - strides (1-9),
|
||||||
|
e - expansion ratio,
|
||||||
|
c - output channels,
|
||||||
|
se - squeeze/excitation ratio
|
||||||
|
Args:
|
||||||
|
block_str: a string representation of block arguments.
|
||||||
|
Returns:
|
||||||
|
A list of block args (dicts)
|
||||||
|
Raises:
|
||||||
|
ValueError: if the string def not properly specified (TODO)
|
||||||
|
"""
|
||||||
|
assert isinstance(block_str, str)
|
||||||
|
ops = block_str.split('_')
|
||||||
|
block_type = ops[0] # take the block type off the front
|
||||||
|
ops = ops[1:]
|
||||||
|
options = {}
|
||||||
|
for op in ops:
|
||||||
|
splits = re.split(r'(\d.*)', op)
|
||||||
|
if len(splits) >= 2:
|
||||||
|
key, value = splits[:2]
|
||||||
|
options[key] = value
|
||||||
|
|
||||||
|
# FIXME validate args and throw
|
||||||
|
|
||||||
|
num_repeat = int(options['r'])
|
||||||
|
# each type of block has different valid arguments, fill accordingly
|
||||||
|
if block_type == 'ir':
|
||||||
|
block_args = dict(
|
||||||
|
block_type=block_type,
|
||||||
|
kernel_size=int(options['k']),
|
||||||
|
out_chs=int(options['c']),
|
||||||
|
exp_ratio=int(options['e']),
|
||||||
|
se_ratio=float(options['se']) if 'se' in options else None,
|
||||||
|
stride=int(options['s']),
|
||||||
|
noskip=('noskip' in block_str),
|
||||||
|
)
|
||||||
|
if 'g' in options:
|
||||||
|
block_args['pw_group'] = options['g']
|
||||||
|
if options['g'] > 1:
|
||||||
|
block_args['shuffle_type'] = 'mid'
|
||||||
|
elif block_type == 'ca':
|
||||||
|
block_args = dict(
|
||||||
|
block_type=block_type,
|
||||||
|
out_chs=int(options['c']),
|
||||||
|
stride=int(options['s']),
|
||||||
|
noskip=('noskip' in block_str),
|
||||||
|
)
|
||||||
|
elif block_type == 'ds' or block_type == 'dsa':
|
||||||
|
block_args = dict(
|
||||||
|
block_type=block_type,
|
||||||
|
kernel_size=int(options['k']),
|
||||||
|
out_chs=int(options['c']),
|
||||||
|
stride=int(options['s']),
|
||||||
|
noskip=block_type == 'dsa' or 'noskip' in block_str,
|
||||||
|
pw_act=block_type == 'dsa',
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert False, 'Unknown block type (%s)' % block_type
|
||||||
|
|
||||||
|
# return a list of block args expanded by num_repeat
|
||||||
|
return [deepcopy(block_args) for _ in range(num_repeat)]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_padding(kernel_size, stride, dilation):
|
||||||
|
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
|
||||||
|
return padding
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_arch_args(string_list):
|
||||||
|
block_args = []
|
||||||
|
for block_str in string_list:
|
||||||
|
block_args.append(_decode_block_str(block_str))
|
||||||
|
return block_args
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_arch_def(arch_def):
|
||||||
|
arch_args = []
|
||||||
|
for stack_idx, block_strings in enumerate(arch_def):
|
||||||
|
assert isinstance(block_strings, list)
|
||||||
|
stack_args = []
|
||||||
|
for block_str in block_strings:
|
||||||
|
assert isinstance(block_str, str)
|
||||||
|
stack_args.extend(_decode_block_str(block_str))
|
||||||
|
arch_args.append(stack_args)
|
||||||
|
return arch_args
|
||||||
|
|
||||||
|
|
||||||
|
class _BlockBuilder:
|
||||||
|
""" Build Trunk Blocks
|
||||||
|
|
||||||
|
This ended up being somewhat of a cross between
|
||||||
|
https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py
|
||||||
|
and
|
||||||
|
https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, depth_multiplier=1.0, depth_divisor=8, min_depth=None,
|
||||||
|
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT):
|
||||||
|
self.depth_multiplier = depth_multiplier
|
||||||
|
self.depth_divisor = depth_divisor
|
||||||
|
self.min_depth = min_depth
|
||||||
|
self.bn_momentum = bn_momentum
|
||||||
|
self.bn_eps = bn_eps
|
||||||
|
self.in_chs = None
|
||||||
|
|
||||||
|
def _round_channels(self, chs):
|
||||||
|
return _round_channels(chs, self.depth_multiplier, self.depth_divisor, self.min_depth)
|
||||||
|
|
||||||
|
def _make_block(self, ba):
|
||||||
|
bt = ba.pop('block_type')
|
||||||
|
ba['in_chs'] = self.in_chs
|
||||||
|
ba['out_chs'] = _round_channels(ba['out_chs'])
|
||||||
|
ba['bn_momentum'] = self.bn_momentum
|
||||||
|
ba['bn_eps'] = self.bn_eps
|
||||||
|
if _DEBUG:
|
||||||
|
print('args:', ba)
|
||||||
|
# could replace this with lambdas or functools binding if variety increases
|
||||||
|
if bt == 'ir':
|
||||||
|
block = InvertedResidual(**ba)
|
||||||
|
elif bt == 'ds' or bt == 'dsa':
|
||||||
|
block = DepthwiseSeparableConv(**ba)
|
||||||
|
elif bt == 'ca':
|
||||||
|
block = CascadeConv3x3(**ba)
|
||||||
|
else:
|
||||||
|
assert False, 'Uknkown block type (%s) while building model.' % bt
|
||||||
|
self.in_chs = ba['out_chs'] # update in_chs for arg of next block
|
||||||
|
return block
|
||||||
|
|
||||||
|
def _make_stack(self, stack_args):
|
||||||
|
blocks = []
|
||||||
|
# each stack (stage) contains a list of block arguments
|
||||||
|
for block_idx, ba in enumerate(stack_args):
|
||||||
|
if _DEBUG:
|
||||||
|
print('block', block_idx, end=', ')
|
||||||
|
if block_idx >= 1:
|
||||||
|
# only the first block in any stack/stage can have a stride > 1
|
||||||
|
ba['stride'] = 1
|
||||||
|
block = self._make_block(ba)
|
||||||
|
blocks.append(block)
|
||||||
|
return nn.Sequential(*blocks)
|
||||||
|
|
||||||
|
def __call__(self, in_chs, arch_def):
|
||||||
|
""" Build the blocks
|
||||||
|
Args:
|
||||||
|
in_chs: Number of input-channels passed to first block
|
||||||
|
arch_def: A list of lists, outer list defines stacks (or stages), inner
|
||||||
|
list contains strings defining block configuration(s)
|
||||||
|
Return:
|
||||||
|
List of block stacks (each stack wrapped in nn.Sequential)
|
||||||
|
"""
|
||||||
|
arch_args = _decode_arch_def(arch_def) # convert and expand string defs to arg dicts
|
||||||
|
if _DEBUG:
|
||||||
|
print('Building model trunk with %d stacks (stages)...' % len(arch_args))
|
||||||
|
self.in_chs = in_chs
|
||||||
|
blocks = []
|
||||||
|
# outer list of arch_args defines the stacks ('stages' by some conventions)
|
||||||
|
for stack_idx, stack in enumerate(arch_args):
|
||||||
|
if _DEBUG:
|
||||||
|
print('stack', stack_idx)
|
||||||
|
assert isinstance(stack, list)
|
||||||
|
stack = self._make_stack(stack)
|
||||||
|
blocks.append(stack)
|
||||||
|
if _DEBUG:
|
||||||
|
print()
|
||||||
|
return blocks
|
||||||
|
|
||||||
|
|
||||||
|
def _initialize_weight(m):
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||||
|
m.weight.data.normal_(0, math.sqrt(2.0 / n))
|
||||||
|
if m.bias is not None:
|
||||||
|
m.bias.data.zero_()
|
||||||
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
|
m.weight.data.fill_(1.0)
|
||||||
|
m.bias.data.zero_()
|
||||||
|
elif isinstance(m, nn.Linear):
|
||||||
|
n = m.weight.size(1)
|
||||||
|
init_range = 1.0 / math.sqrt(n)
|
||||||
|
m.weight.data.uniform_(-init_range, init_range)
|
||||||
|
m.bias.data.zero_()
|
||||||
|
|
||||||
|
|
||||||
|
class DepthwiseSeparableConv(nn.Module):
|
||||||
|
def __init__(self, in_chs, out_chs, kernel_size,
|
||||||
|
stride=1, act_fn=F.relu, noskip=False, pw_act=False,
|
||||||
|
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT):
|
||||||
|
super(DepthwiseSeparableConv, self).__init__()
|
||||||
|
assert stride in [1, 2]
|
||||||
|
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
|
||||||
|
self.has_pw_act = pw_act # activation after point-wise conv
|
||||||
|
self.act_fn = act_fn
|
||||||
|
|
||||||
|
self.conv_dw = nn.Conv2d(
|
||||||
|
in_chs, in_chs, kernel_size,
|
||||||
|
stride=stride, padding=kernel_size // 2, groups=in_chs, bias=False)
|
||||||
|
self.bn1 = nn.BatchNorm2d(in_chs, momentum=bn_momentum, eps=bn_eps)
|
||||||
|
self.conv_pw = nn.Conv2d(in_chs, out_chs, 1, bias=False)
|
||||||
|
self.bn2 = nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
residual = x
|
||||||
|
x = self.conv_dw(x)
|
||||||
|
x = self.bn1(x)
|
||||||
|
x = self.act_fn(x)
|
||||||
|
x = self.conv_pw(x)
|
||||||
|
x = self.bn2(x)
|
||||||
|
if self.has_pw_act:
|
||||||
|
x = self.act_fn(x)
|
||||||
|
if self.has_residual:
|
||||||
|
x += residual
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CascadeConv3x3(nn.Sequential):
|
||||||
|
# FIXME lifted from maskrcnn_benchmark blocks, haven't used yet
|
||||||
|
def __init__(self, in_chs, out_chs, stride, act_fn=F.relu, noskip=False,
|
||||||
|
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT):
|
||||||
|
super(CascadeConv3x3, self).__init__()
|
||||||
|
assert stride in [1, 2]
|
||||||
|
self.has_residual = not noskip and (stride == 1 and in_chs == out_chs)
|
||||||
|
self.act_fn = act_fn
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(in_chs, in_chs, 3, stride=stride, padding=1, bias=False)
|
||||||
|
self.bn1 = nn.BatchNorm2d(in_chs, momentum=bn_momentum, eps=bn_eps)
|
||||||
|
self.conv2 = nn.Conv2d(in_chs, out_chs, 3, stride=1, padding=1, bias=False)
|
||||||
|
self.bn2 = nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
residual = x
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.bn1(x)
|
||||||
|
x = self.act_fn(x)
|
||||||
|
x = self.conv2(x)
|
||||||
|
x = self.bn2(x)
|
||||||
|
if self.has_residual:
|
||||||
|
x += residual
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelShuffle(nn.Module):
|
||||||
|
# FIXME lifted from maskrcnn_benchmark blocks, haven't used yet
|
||||||
|
def __init__(self, groups):
|
||||||
|
super(ChannelShuffle, self).__init__()
|
||||||
|
self.groups = groups
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]"""
|
||||||
|
N, C, H, W = x.size()
|
||||||
|
g = self.groups
|
||||||
|
assert C % g == 0, "Incompatible group size {} for input channel {}".format(
|
||||||
|
g, C
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
x.view(N, g, int(C / g), H, W)
|
||||||
|
.permute(0, 2, 1, 3, 4)
|
||||||
|
.contiguous()
|
||||||
|
.view(N, C, H, W)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SqueezeExcite(nn.Module):
|
||||||
|
def __init__(self, in_chs, se_ratio=0.25, act_fn=F.relu):
|
||||||
|
super(SqueezeExcite, self).__init__()
|
||||||
|
self.act_fn = act_fn
|
||||||
|
reduced_chs = max(1, int(in_chs * se_ratio))
|
||||||
|
self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
|
||||||
|
self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# NOTE adaptiveavgpool can be used here, but seems to cause issues with NVIDIA AMP performance
|
||||||
|
x_se = x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
|
||||||
|
x_se = self.conv_reduce(x_se)
|
||||||
|
x_se = self.act_fn(x_se)
|
||||||
|
x_se = self.conv_expand(x_se)
|
||||||
|
x = torch.sigmoid(x_se) * x
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class InvertedResidual(nn.Module):
|
||||||
|
""" Inverted residual block w/ optional SE"""
|
||||||
|
|
||||||
|
def __init__(self, in_chs, out_chs, kernel_size,
|
||||||
|
stride=1, act_fn=F.relu, exp_ratio=1.0, noskip=False,
|
||||||
|
se_ratio=0., shuffle_type=None, pw_group=1,
|
||||||
|
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT):
|
||||||
|
super(InvertedResidual, self).__init__()
|
||||||
|
mid_chs = int(in_chs * exp_ratio)
|
||||||
|
self.has_se = se_ratio is not None and se_ratio > 0.
|
||||||
|
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
|
||||||
|
self.act_fn = act_fn
|
||||||
|
|
||||||
|
# Point-wise expansion
|
||||||
|
self.conv_pw = nn.Conv2d(in_chs, mid_chs, 1, groups=pw_group, bias=False)
|
||||||
|
self.bn1 = nn.BatchNorm2d(mid_chs, momentum=bn_momentum, eps=bn_eps)
|
||||||
|
|
||||||
|
self.shuffle_type = shuffle_type
|
||||||
|
if shuffle_type is not None:
|
||||||
|
self.shuffle = ChannelShuffle(pw_group)
|
||||||
|
|
||||||
|
# Depth-wise convolution
|
||||||
|
self.conv_dw = nn.Conv2d(
|
||||||
|
mid_chs, mid_chs, kernel_size, padding=kernel_size // 2,
|
||||||
|
stride=stride, groups=mid_chs, bias=False)
|
||||||
|
self.bn2 = nn.BatchNorm2d(mid_chs, momentum=bn_momentum, eps=bn_eps)
|
||||||
|
|
||||||
|
# Squeeze-and-excitation
|
||||||
|
if self.has_se:
|
||||||
|
self.se = SqueezeExcite(mid_chs, se_ratio)
|
||||||
|
|
||||||
|
# Point-wise linear projection
|
||||||
|
self.conv_pwl = nn.Conv2d(mid_chs, out_chs, 1, groups=pw_group, bias=False)
|
||||||
|
self.bn3 = nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
residual = x
|
||||||
|
|
||||||
|
# Point-wise expansion
|
||||||
|
x = self.conv_pw(x)
|
||||||
|
x = self.bn1(x)
|
||||||
|
x = self.act_fn(x)
|
||||||
|
|
||||||
|
# FIXME haven't tried this yet
|
||||||
|
# for channel shuffle when using groups with pointwise convs as per FBNet variants
|
||||||
|
if self.shuffle_type == "mid":
|
||||||
|
x = self.shuffle(x)
|
||||||
|
|
||||||
|
# Depth-wise convolution
|
||||||
|
x = self.conv_dw(x)
|
||||||
|
x = self.bn2(x)
|
||||||
|
x = self.act_fn(x)
|
||||||
|
|
||||||
|
# Squeeze-and-excitation
|
||||||
|
if self.has_se:
|
||||||
|
x = self.se(x)
|
||||||
|
|
||||||
|
# Point-wise linear projection
|
||||||
|
x = self.conv_pwl(x)
|
||||||
|
x = self.bn3(x)
|
||||||
|
|
||||||
|
if self.has_residual:
|
||||||
|
x += residual
|
||||||
|
|
||||||
|
# NOTE maskrcnn_benchmark building blocks have an SE module defined here for some variants
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class GenMobileNet(nn.Module):
|
||||||
|
""" Generic Mobile Net
|
||||||
|
|
||||||
|
An implementation of mobile optimized networks that covers:
|
||||||
|
* MobileNetV1
|
||||||
|
* MobileNetV2
|
||||||
|
* MNASNet A1, B1, and small
|
||||||
|
* FBNet A, B, and C
|
||||||
|
* ChamNet (arch details are murky)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280,
|
||||||
|
depth_multiplier=1.0, depth_divisor=8, min_depth=None,
|
||||||
|
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
|
||||||
|
drop_rate=0., act_fn=F.relu, global_pool='avg', skip_head_conv=False):
|
||||||
|
super(GenMobileNet, self).__init__()
|
||||||
|
self.num_classes = num_classes
|
||||||
|
self.depth_multiplier = depth_multiplier
|
||||||
|
self.drop_rate = drop_rate
|
||||||
|
self.act_fn = act_fn
|
||||||
|
self.num_features = num_features
|
||||||
|
|
||||||
|
stem_size = _round_channels(stem_size, depth_multiplier, depth_divisor, min_depth)
|
||||||
|
self.conv_stem = nn.Conv2d(in_chans, stem_size, 3, padding=1, stride=2, bias=False)
|
||||||
|
self.bn1 = nn.BatchNorm2d(stem_size, momentum=bn_momentum, eps=bn_eps)
|
||||||
|
in_chs = stem_size
|
||||||
|
|
||||||
|
builder = _BlockBuilder(
|
||||||
|
depth_multiplier, depth_divisor, min_depth,
|
||||||
|
bn_momentum, bn_eps)
|
||||||
|
self.blocks = nn.Sequential(*builder(in_chs, block_args))
|
||||||
|
in_chs = builder.in_chs
|
||||||
|
|
||||||
|
if skip_head_conv:
|
||||||
|
self.conv_head = None
|
||||||
|
assert in_chs == self.num_features
|
||||||
|
else:
|
||||||
|
self.conv_head = nn.Conv2d(in_chs, self.num_features, 1, padding=0, stride=1, bias=False)
|
||||||
|
self.bn2 = nn.BatchNorm2d(self.num_features, momentum=bn_momentum, eps=bn_eps)
|
||||||
|
|
||||||
|
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||||
|
self.classifier = nn.Linear(self.num_features, self.num_classes)
|
||||||
|
|
||||||
|
for m in self.modules():
|
||||||
|
_initialize_weight(m)
|
||||||
|
|
||||||
|
def get_classifier(self):
|
||||||
|
return self.classifier
|
||||||
|
|
||||||
|
def reset_classifier(self, num_classes, global_pool='avg'):
|
||||||
|
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||||
|
self.num_classes = num_classes
|
||||||
|
del self.classifier
|
||||||
|
if num_classes:
|
||||||
|
self.classifier = nn.Linear(
|
||||||
|
self.num_features * self.global_pool.feat_mult(), num_classes)
|
||||||
|
else:
|
||||||
|
self.classifier = None
|
||||||
|
|
||||||
|
def forward_features(self, x, pool=True):
|
||||||
|
x = self.conv_stem(x)
|
||||||
|
x = self.bn1(x)
|
||||||
|
x = self.act_fn(x)
|
||||||
|
x = self.blocks(x)
|
||||||
|
if self.conv_head is not None:
|
||||||
|
x = self.conv_head(x)
|
||||||
|
x = self.bn2(x)
|
||||||
|
x = self.act_fn(x)
|
||||||
|
if pool:
|
||||||
|
x = self.global_pool(x)
|
||||||
|
x = x.view(x.size(0), -1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.forward_features(x)
|
||||||
|
if self.drop_rate > 0.:
|
||||||
|
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
||||||
|
return self.classifier(x)
|
||||||
|
|
||||||
|
|
||||||
|
def _gen_mnasnet_a1(depth_multiplier, num_classes=1000, **kwargs):
|
||||||
|
"""Creates a mnasnet-a1 model.
|
||||||
|
|
||||||
|
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
|
||||||
|
Paper: https://arxiv.org/pdf/1807.11626.pdf.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
depth_multiplier: multiplier to number of channels per layer.
|
||||||
|
"""
|
||||||
|
arch_def = [
|
||||||
|
# stage 0, 112x112 in
|
||||||
|
['ds_r1_k3_s1_e1_c16_noskip'],
|
||||||
|
# stage 1, 112x112 in
|
||||||
|
['ir_r2_k3_s2_e6_c24'],
|
||||||
|
# stage 2, 56x56 in
|
||||||
|
['ir_r3_k5_s2_e3_c40_se0.25'],
|
||||||
|
# stage 3, 28x28 in
|
||||||
|
['ir_r4_k3_s2_e6_c80'],
|
||||||
|
# stage 4, 14x14in
|
||||||
|
['ir_r2_k3_s1_e6_c112_se0.25'],
|
||||||
|
# stage 5, 14x14in
|
||||||
|
['ir_r3_k5_s2_e6_c160_se0.25'],
|
||||||
|
# stage 6, 7x7 in
|
||||||
|
['ir_r1_k3_s1_e6_c320'],
|
||||||
|
]
|
||||||
|
bn_momentum, bn_eps = _resolve_bn_params(kwargs)
|
||||||
|
model = GenMobileNet(
|
||||||
|
arch_def,
|
||||||
|
num_classes=num_classes,
|
||||||
|
stem_size=32,
|
||||||
|
depth_multiplier=depth_multiplier,
|
||||||
|
depth_divisor=8,
|
||||||
|
min_depth=None,
|
||||||
|
bn_momentum=bn_momentum,
|
||||||
|
bn_eps=bn_eps,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def _gen_mnasnet_b1(depth_multiplier, num_classes=1000, **kwargs):
|
||||||
|
"""Creates a mnasnet-b1 model.
|
||||||
|
|
||||||
|
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
|
||||||
|
Paper: https://arxiv.org/pdf/1807.11626.pdf.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
depth_multiplier: multiplier to number of channels per layer.
|
||||||
|
"""
|
||||||
|
arch_def = [
|
||||||
|
# stage 0, 112x112 in
|
||||||
|
['ds_r1_k3_s1_c16_noskip'],
|
||||||
|
# stage 1, 112x112 in
|
||||||
|
['ir_r3_k3_s2_e3_c24'],
|
||||||
|
# stage 2, 56x56 in
|
||||||
|
['ir_r3_k5_s2_e3_c40'],
|
||||||
|
# stage 3, 28x28 in
|
||||||
|
['ir_r3_k5_s2_e6_c80'],
|
||||||
|
# stage 4, 14x14in
|
||||||
|
['ir_r2_k3_s1_e6_c96'],
|
||||||
|
# stage 5, 14x14in
|
||||||
|
['ir_r4_k5_s2_e6_c192'],
|
||||||
|
# stage 6, 7x7 in
|
||||||
|
['ir_r1_k3_s1_e6_c320_noskip']
|
||||||
|
]
|
||||||
|
bn_momentum, bn_eps = _resolve_bn_params(kwargs)
|
||||||
|
model = GenMobileNet(
|
||||||
|
arch_def,
|
||||||
|
num_classes=num_classes,
|
||||||
|
stem_size=32,
|
||||||
|
depth_multiplier=depth_multiplier,
|
||||||
|
depth_divisor=8,
|
||||||
|
min_depth=None,
|
||||||
|
bn_momentum=bn_momentum,
|
||||||
|
bn_eps=bn_eps,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def _gen_mnasnet_small(depth_multiplier, num_classes=1000, **kwargs):
|
||||||
|
"""Creates a mnasnet-b1 model.
|
||||||
|
|
||||||
|
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
|
||||||
|
Paper: https://arxiv.org/pdf/1807.11626.pdf.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
depth_multiplier: multiplier to number of channels per layer.
|
||||||
|
"""
|
||||||
|
arch_def = [
|
||||||
|
['ds_r1_k3_s1_c8'],
|
||||||
|
['ir_r1_k3_s2_e3_c16'],
|
||||||
|
['ir_r2_k3_s2_e6_c16'],
|
||||||
|
['ir_r4_k5_s2_e6_c32_se0.25'],
|
||||||
|
['ir_r3_k3_s1_e6_c32_se0.25'],
|
||||||
|
['ir_r3_k5_s2_e6_c88_se0.25'],
|
||||||
|
['ir_r1_k3_s1_e6_c144']
|
||||||
|
]
|
||||||
|
bn_momentum, bn_eps = _resolve_bn_params(kwargs)
|
||||||
|
model = GenMobileNet(
|
||||||
|
arch_def,
|
||||||
|
num_classes=num_classes,
|
||||||
|
stem_size=8,
|
||||||
|
depth_multiplier=depth_multiplier,
|
||||||
|
depth_divisor=8,
|
||||||
|
min_depth=None,
|
||||||
|
bn_momentum=bn_momentum,
|
||||||
|
bn_eps=bn_eps,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def _gen_mobilenet_v1(depth_multiplier, num_classes=1000, **kwargs):
|
||||||
|
"""
|
||||||
|
Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
|
||||||
|
Paper: https://arxiv.org/abs/1801.04381
|
||||||
|
"""
|
||||||
|
arch_def = [
|
||||||
|
['dsa_r1_k3_s1_c64'],
|
||||||
|
['dsa_r2_k3_s2_c128'],
|
||||||
|
['dsa_r2_k3_s2_c256'],
|
||||||
|
['dsa_r6_k3_s2_c512'],
|
||||||
|
['dsa_r2_k3_s2_c1024'],
|
||||||
|
]
|
||||||
|
bn_momentum, bn_eps = _resolve_bn_params(kwargs)
|
||||||
|
model = GenMobileNet(
|
||||||
|
arch_def,
|
||||||
|
num_classes=num_classes,
|
||||||
|
stem_size=32,
|
||||||
|
num_features=1024,
|
||||||
|
depth_multiplier=depth_multiplier,
|
||||||
|
depth_divisor=8,
|
||||||
|
min_depth=None,
|
||||||
|
bn_momentum=bn_momentum,
|
||||||
|
bn_eps=bn_eps,
|
||||||
|
act_fn=F.relu6,
|
||||||
|
skip_head_conv=True,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def _gen_mobilenet_v2(depth_multiplier, num_classes=1000, **kwargs):
|
||||||
|
""" Generate MobileNet-V2 network
|
||||||
|
Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
|
||||||
|
Paper: https://arxiv.org/abs/1801.04381
|
||||||
|
"""
|
||||||
|
arch_def = [
|
||||||
|
['ds_r1_k3_s1_c16'],
|
||||||
|
['ir_r2_k3_s2_e6_c24'],
|
||||||
|
['ir_r3_k3_s2_e6_c32'],
|
||||||
|
['ir_r4_k3_s2_e6_c64'],
|
||||||
|
['ir_r3_k3_s1_e6_c96'],
|
||||||
|
['ir_r3_k3_s2_e6_c160'],
|
||||||
|
['ir_r1_k3_s1_e6_c320'],
|
||||||
|
]
|
||||||
|
bn_momentum, bn_eps = _resolve_bn_params(kwargs)
|
||||||
|
model = GenMobileNet(
|
||||||
|
arch_def,
|
||||||
|
num_classes=num_classes,
|
||||||
|
stem_size=32,
|
||||||
|
depth_multiplier=depth_multiplier,
|
||||||
|
depth_divisor=8,
|
||||||
|
min_depth=None,
|
||||||
|
bn_momentum=bn_momentum,
|
||||||
|
bn_eps=bn_eps,
|
||||||
|
act_fn=F.relu6,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def _gen_chamnet_v1(depth_multiplier, num_classes=1000, **kwargs):
|
||||||
|
""" Generate Chameleon Network (ChamNet)
|
||||||
|
|
||||||
|
Paper: https://arxiv.org/abs/1812.08934
|
||||||
|
Ref Impl: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py
|
||||||
|
|
||||||
|
FIXME: this a bit of an educated guess based on trunkd def in maskrcnn_benchmark
|
||||||
|
"""
|
||||||
|
arch_def = [
|
||||||
|
['ir_r1_k3_s1_e1_c24'],
|
||||||
|
['ir_r2_k7_s2_e4_c48'],
|
||||||
|
['ir_r5_k3_s2_e7_c64'],
|
||||||
|
['ir_r7_k5_s2_e12_c56'],
|
||||||
|
['ir_r5_k3_s1_e8_c88'],
|
||||||
|
['ir_r4_k3_s2_e7_c152'],
|
||||||
|
['ir_r1_k3_s1_e10_c104'],
|
||||||
|
]
|
||||||
|
bn_momentum, bn_eps = _resolve_bn_params(kwargs)
|
||||||
|
model = GenMobileNet(
|
||||||
|
arch_def,
|
||||||
|
num_classes=num_classes,
|
||||||
|
stem_size=32,
|
||||||
|
num_features=1280, # no idea what this is? try mobile/mnasnet default?
|
||||||
|
depth_multiplier=depth_multiplier,
|
||||||
|
depth_divisor=8,
|
||||||
|
min_depth=None,
|
||||||
|
bn_momentum=bn_momentum,
|
||||||
|
bn_eps=bn_eps,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def _gen_chamnet_v2(depth_multiplier, num_classes=1000, **kwargs):
|
||||||
|
""" Generate Chameleon Network (ChamNet)
|
||||||
|
|
||||||
|
Paper: https://arxiv.org/abs/1812.08934
|
||||||
|
Ref Impl: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py
|
||||||
|
|
||||||
|
FIXME: this a bit of an educated guess based on trunk def in maskrcnn_benchmark
|
||||||
|
"""
|
||||||
|
arch_def = [
|
||||||
|
['ir_r1_k3_s1_e1_c24'],
|
||||||
|
['ir_r4_k5_s2_e8_c32'],
|
||||||
|
['ir_r6_k7_s2_e5_c48'],
|
||||||
|
['ir_r3_k5_s2_e9_c56'],
|
||||||
|
['ir_r6_k3_s1_e6_c56'],
|
||||||
|
['ir_r6_k3_s2_e2_c152'],
|
||||||
|
['ir_r1_k3_s1_e6_c112'],
|
||||||
|
]
|
||||||
|
bn_momentum, bn_eps = _resolve_bn_params(kwargs)
|
||||||
|
model = GenMobileNet(
|
||||||
|
arch_def,
|
||||||
|
num_classes=num_classes,
|
||||||
|
stem_size=32,
|
||||||
|
num_features=1280, # no idea what this is? try mobile/mnasnet default?
|
||||||
|
depth_multiplier=depth_multiplier,
|
||||||
|
depth_divisor=8,
|
||||||
|
min_depth=None,
|
||||||
|
bn_momentum=bn_momentum,
|
||||||
|
bn_eps=bn_eps,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def _gen_fbnetc(depth_multiplier, num_classes=1000, **kwargs):
|
||||||
|
""" FBNet-C
|
||||||
|
|
||||||
|
Paper: https://arxiv.org/abs/1812.03443
|
||||||
|
Ref Impl: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py
|
||||||
|
|
||||||
|
NOTE: the impl above does not relate to the 'C' variant here, that was derived from paper,
|
||||||
|
it was used to confirm some building block details
|
||||||
|
"""
|
||||||
|
arch_def = [
|
||||||
|
['ir_r1_k3_s1_e1_c16'],
|
||||||
|
['ir_r1_k3_s2_e6_c24', 'ir_r2_k3_s1_e1_c24'],
|
||||||
|
['ir_r1_k5_s2_e6_c32', 'ir_r1_k5_s1_e3_c32', 'ir_r1_k5_s1_e6_c32', 'ir_r1_k3_s1_e6_c32'],
|
||||||
|
['ir_r1_k5_s2_e6_c64', 'ir_r1_k5_s1_e3_c64', 'ir_r2_k5_s1_e6_c64'],
|
||||||
|
['ir_r3_k5_s1_e6_c112', 'ir_r1_k5_s1_e3_c112'],
|
||||||
|
['ir_r4_k5_s2_e6_c184'],
|
||||||
|
['ir_r1_k3_s1_e6_c352'],
|
||||||
|
]
|
||||||
|
bn_momentum, bn_eps = _resolve_bn_params(kwargs)
|
||||||
|
model = GenMobileNet(
|
||||||
|
arch_def,
|
||||||
|
num_classes=num_classes,
|
||||||
|
stem_size=16,
|
||||||
|
num_features=1984, # paper suggests this, but is not 100% clear
|
||||||
|
depth_multiplier=depth_multiplier,
|
||||||
|
depth_divisor=8,
|
||||||
|
min_depth=None,
|
||||||
|
bn_momentum=bn_momentum,
|
||||||
|
bn_eps=bn_eps,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def mnasnet0_50(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
|
||||||
|
""" MNASNet B1, depth multiplier of 0.5. """
|
||||||
|
default_cfg = default_cfgs['mnasnet0_50']
|
||||||
|
model = _gen_mnasnet_b1(0.5, num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||||
|
model.default_cfg = default_cfg
|
||||||
|
if pretrained:
|
||||||
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def mnasnet0_75(num_classes, in_chans=3, pretrained=False, **kwargs):
|
||||||
|
""" MNASNet B1, depth multiplier of 0.75. """
|
||||||
|
default_cfg = default_cfgs['mnasnet0_75']
|
||||||
|
model = _gen_mnasnet_b1(0.75, num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||||
|
model.default_cfg = default_cfg
|
||||||
|
if pretrained:
|
||||||
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def mnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs):
|
||||||
|
""" MNASNet B1, depth multiplier of 1.0. """
|
||||||
|
default_cfg = default_cfgs['mnasnet1_00']
|
||||||
|
model = _gen_mnasnet_b1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||||
|
model.default_cfg = default_cfg
|
||||||
|
if pretrained:
|
||||||
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def mnasnet1_40(num_classes, in_chans=3, pretrained=False, **kwargs):
|
||||||
|
""" MNASNet B1, depth multiplier of 1.4 """
|
||||||
|
default_cfg = default_cfgs['mnasnet1_40']
|
||||||
|
model = _gen_mnasnet_b1(1.4, num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||||
|
model.default_cfg = default_cfg
|
||||||
|
if pretrained:
|
||||||
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def semnasnet0_50(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
|
||||||
|
""" MNASNet A1 (w/ SE), depth multiplier of 0.5 """
|
||||||
|
default_cfg = default_cfgs['semnasnet0_50']
|
||||||
|
model = _gen_mnasnet_a1(0.5, num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||||
|
model.default_cfg = default_cfg
|
||||||
|
if pretrained:
|
||||||
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def semnasnet0_75(num_classes, in_chans=3, pretrained=False, **kwargs):
|
||||||
|
""" MNASNet A1 (w/ SE), depth multiplier of 0.75. """
|
||||||
|
default_cfg = default_cfgs['semnasnet0_75']
|
||||||
|
model = _gen_mnasnet_a1(0.75, num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||||
|
model.default_cfg = default_cfg
|
||||||
|
if pretrained:
|
||||||
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def semnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs):
|
||||||
|
""" MNASNet A1 (w/ SE), depth multiplier of 1.0. """
|
||||||
|
default_cfg = default_cfgs['semnasnet1_00']
|
||||||
|
model = _gen_mnasnet_a1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||||
|
model.default_cfg = default_cfg
|
||||||
|
if pretrained:
|
||||||
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def semnasnet1_40(num_classes, in_chans=3, pretrained=False, **kwargs):
|
||||||
|
""" MNASNet A1 (w/ SE), depth multiplier of 1.4. """
|
||||||
|
default_cfg = default_cfgs['semnasnet1_40']
|
||||||
|
model = _gen_mnasnet_a1(1.4, num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||||
|
model.default_cfg = default_cfg
|
||||||
|
if pretrained:
|
||||||
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def mnasnet_small(num_classes, in_chans=3, pretrained=False, **kwargs):
|
||||||
|
""" MNASNet Small, depth multiplier of 1.0. """
|
||||||
|
default_cfg = default_cfgs['mnasnet_small']
|
||||||
|
model = _gen_mnasnet_small(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||||
|
model.default_cfg = default_cfg
|
||||||
|
if pretrained:
|
||||||
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def mobilenetv1_1_00(num_classes, in_chans=3, pretrained=False, **kwargs):
|
||||||
|
""" MobileNet V1 """
|
||||||
|
default_cfg = default_cfgs['mobilenetv1_1_00']
|
||||||
|
model = _gen_mobilenet_v1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||||
|
model.default_cfg = default_cfg
|
||||||
|
if pretrained:
|
||||||
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def mobilenetv2_1_00(num_classes, in_chans=3, pretrained=False, **kwargs):
|
||||||
|
""" MobileNet V2 """
|
||||||
|
default_cfg = default_cfgs['mobilenetv2_1_00']
|
||||||
|
model = _gen_mobilenet_v2(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||||
|
model.default_cfg = default_cfg
|
||||||
|
if pretrained:
|
||||||
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def fbnetc_1_00(num_classes, in_chans=3, pretrained=False, **kwargs):
|
||||||
|
""" FBNet-C """
|
||||||
|
default_cfg = default_cfgs['fbnetc_1_00']
|
||||||
|
model = _gen_fbnetc(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||||
|
model.default_cfg = default_cfg
|
||||||
|
if pretrained:
|
||||||
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def chamnetv1_1_00(num_classes, in_chans=3, pretrained=False, **kwargs):
|
||||||
|
""" ChamNet """
|
||||||
|
default_cfg = default_cfgs['chamnetv1_1_00']
|
||||||
|
model = _gen_chamnet_v1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||||
|
model.default_cfg = default_cfg
|
||||||
|
if pretrained:
|
||||||
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def chamnetv2_1_00(num_classes, in_chans=3, pretrained=False, **kwargs):
|
||||||
|
""" ChamNet """
|
||||||
|
default_cfg = default_cfgs['chamnetv2_1_00']
|
||||||
|
model = _gen_chamnet_v2(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||||
|
model.default_cfg = default_cfg
|
||||||
|
if pretrained:
|
||||||
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||||
|
return model
|
@ -1,462 +0,0 @@
|
|||||||
""" MNASNet (a1, b1, and small)
|
|
||||||
|
|
||||||
Based on offical TF implementation w/ round_channels,
|
|
||||||
decode_block_str, and model block args directly transferred
|
|
||||||
https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
|
|
||||||
|
|
||||||
Original paper: https://arxiv.org/pdf/1807.11626.pdf.
|
|
||||||
|
|
||||||
Hacked together by Ross Wightman
|
|
||||||
"""
|
|
||||||
|
|
||||||
import math
|
|
||||||
import re
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from models.helpers import load_pretrained
|
|
||||||
from models.adaptive_avgmax_pool import SelectAdaptivePool2d
|
|
||||||
from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
||||||
|
|
||||||
__all__ = ['MnasNet', 'mnasnet0_50', 'mnasnet0_75', 'mnasnet1_00', 'mnasnet1_40',
|
|
||||||
'semnasnet0_50', 'semnasnet0_75', 'semnasnet1_00', 'semnasnet1_40',
|
|
||||||
'mnasnet_small']
|
|
||||||
|
|
||||||
|
|
||||||
def _cfg(url='', **kwargs):
|
|
||||||
return {
|
|
||||||
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
|
||||||
'crop_pct': 0.875, 'interpolation': 'bilinear',
|
|
||||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
|
||||||
'first_conv': 'conv_stem', 'classifier': 'classifier',
|
|
||||||
**kwargs
|
|
||||||
}
|
|
||||||
|
|
||||||
default_cfgs = {
|
|
||||||
'mnasnet0_50': _cfg(url=''),
|
|
||||||
'mnasnet0_75': _cfg(url=''),
|
|
||||||
'mnasnet1_00': _cfg(url=''),
|
|
||||||
'mnasnet1_40': _cfg(url=''),
|
|
||||||
'semnasnet0_50': _cfg(url=''),
|
|
||||||
'semnasnet0_75': _cfg(url=''),
|
|
||||||
'semnasnet1_00': _cfg(url=''),
|
|
||||||
'semnasnet1_40': _cfg(url=''),
|
|
||||||
'mnasnet_small': _cfg(url=''),
|
|
||||||
}
|
|
||||||
|
|
||||||
_BN_MOMENTUM_DEFAULT = 1 - 0.99
|
|
||||||
_BN_EPS_DEFAULT = 1e-3
|
|
||||||
|
|
||||||
|
|
||||||
def _round_channels(channels, depth_multiplier=1.0, depth_divisor=8, min_depth=None):
|
|
||||||
"""Round number of filters based on depth multiplier."""
|
|
||||||
multiplier = depth_multiplier
|
|
||||||
divisor = depth_divisor
|
|
||||||
min_depth = min_depth
|
|
||||||
if not multiplier:
|
|
||||||
return channels
|
|
||||||
|
|
||||||
channels *= multiplier
|
|
||||||
min_depth = min_depth or divisor
|
|
||||||
new_channels = max(min_depth, int(channels + divisor / 2) // divisor * divisor)
|
|
||||||
# Make sure that round down does not go down by more than 10%.
|
|
||||||
if new_channels < 0.9 * channels:
|
|
||||||
new_channels += divisor
|
|
||||||
return new_channels
|
|
||||||
|
|
||||||
|
|
||||||
def _decode_block_str(block_str):
|
|
||||||
"""Gets a MNasNet block through a string notation of arguments.
|
|
||||||
E.g. r2_k3_s2_e1_i32_o16_se0.25_noskip:
|
|
||||||
r - number of repeat blocks,
|
|
||||||
k - kernel size,
|
|
||||||
s - strides (1-9),
|
|
||||||
e - expansion ratio,
|
|
||||||
i - input filters,
|
|
||||||
o - output filters,
|
|
||||||
se - squeeze/excitation ratio
|
|
||||||
Args:
|
|
||||||
block_string: a string, a string representation of block arguments.
|
|
||||||
Returns:
|
|
||||||
A BlockArgs instance.
|
|
||||||
Raises:
|
|
||||||
ValueError: if the strides option is not correctly specified.
|
|
||||||
"""
|
|
||||||
assert isinstance(block_str, str)
|
|
||||||
ops = block_str.split('_')
|
|
||||||
options = {}
|
|
||||||
for op in ops:
|
|
||||||
splits = re.split(r'(\d.*)', op)
|
|
||||||
if len(splits) >= 2:
|
|
||||||
key, value = splits[:2]
|
|
||||||
options[key] = value
|
|
||||||
|
|
||||||
if 's' not in options or len(options['s']) != 2:
|
|
||||||
raise ValueError('Strides options should be a pair of integers.')
|
|
||||||
|
|
||||||
return dict(
|
|
||||||
kernel_size=int(options['k']),
|
|
||||||
num_repeat=int(options['r']),
|
|
||||||
in_chs=int(options['i']),
|
|
||||||
out_chs=int(options['o']),
|
|
||||||
exp_ratio=int(options['e']),
|
|
||||||
id_skip=('noskip' not in block_str),
|
|
||||||
se_ratio=float(options['se']) if 'se' in options else None,
|
|
||||||
stride=int(options['s'][0]) # TF impl passes a list of two strides
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _decode_block_args(string_list):
|
|
||||||
block_args = []
|
|
||||||
for block_str in string_list:
|
|
||||||
block_args.append(_decode_block_str(block_str))
|
|
||||||
return block_args
|
|
||||||
|
|
||||||
|
|
||||||
def _initialize_weight(m):
|
|
||||||
if isinstance(m, nn.Conv2d):
|
|
||||||
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
|
||||||
m.weight.data.normal_(0, math.sqrt(2.0 / n))
|
|
||||||
if m.bias is not None:
|
|
||||||
m.bias.data.zero_()
|
|
||||||
elif isinstance(m, nn.BatchNorm2d):
|
|
||||||
m.weight.data.fill_(1.0)
|
|
||||||
m.bias.data.zero_()
|
|
||||||
elif isinstance(m, nn.Linear):
|
|
||||||
n = m.weight.size(1)
|
|
||||||
init_range = 1.0 / math.sqrt(n)
|
|
||||||
m.weight.data.uniform_(-init_range, init_range)
|
|
||||||
m.bias.data.zero_()
|
|
||||||
|
|
||||||
|
|
||||||
class MnasBlock(nn.Module):
|
|
||||||
""" MNASNet Inverted residual block w/ optional SE"""
|
|
||||||
|
|
||||||
def __init__(self, in_chs, out_chs, kernel_size, stride,
|
|
||||||
exp_ratio=1.0, id_skip=True, se_ratio=0.,
|
|
||||||
bn_momentum=0.1, bn_eps=1e-3, act_fn=F.relu):
|
|
||||||
super(MnasBlock, self).__init__()
|
|
||||||
exp_chs = int(in_chs * exp_ratio)
|
|
||||||
self.has_expansion = exp_ratio != 1
|
|
||||||
self.has_se = se_ratio is not None and se_ratio > 0.
|
|
||||||
self.has_residual = id_skip and (in_chs == out_chs and stride == 1)
|
|
||||||
self.act_fn = act_fn
|
|
||||||
|
|
||||||
# Pointwise expansion
|
|
||||||
if self.has_expansion:
|
|
||||||
self.conv_expand = nn.Conv2d(in_chs, exp_chs, 1, bias=False)
|
|
||||||
self.bn0 = nn.BatchNorm2d(exp_chs, momentum=bn_momentum, eps=bn_eps)
|
|
||||||
|
|
||||||
# Depthwise convolution
|
|
||||||
self.conv_depthwise = nn.Conv2d(
|
|
||||||
exp_chs, exp_chs, kernel_size, padding=kernel_size // 2,
|
|
||||||
stride=stride, groups=exp_chs, bias=False)
|
|
||||||
self.bn1 = nn.BatchNorm2d(exp_chs, momentum=bn_momentum, eps=bn_eps)
|
|
||||||
|
|
||||||
# Squeeze-and-excitation
|
|
||||||
if self.has_se:
|
|
||||||
num_reduced_ch = max(1, int(in_chs * se_ratio))
|
|
||||||
self.conv_se_reduce = nn.Conv2d(exp_chs, num_reduced_ch, 1, bias=True)
|
|
||||||
self.conv_se_expand = nn.Conv2d(num_reduced_ch, exp_chs, 1, bias=True)
|
|
||||||
|
|
||||||
# Pointwise projection
|
|
||||||
self.conv_project = nn.Conv2d(exp_chs, out_chs, 1, bias=False)
|
|
||||||
self.bn2 = nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
residual = x
|
|
||||||
# Pointwise expansion
|
|
||||||
if self.has_expansion:
|
|
||||||
x = self.conv_expand(x)
|
|
||||||
x = self.bn0(x)
|
|
||||||
x = self.act_fn(x)
|
|
||||||
# Depthwise convolution
|
|
||||||
x = self.conv_depthwise(x)
|
|
||||||
x = self.bn1(x)
|
|
||||||
x = self.act_fn(x)
|
|
||||||
# Squeeze-and-excitation
|
|
||||||
if self.has_se:
|
|
||||||
x_se = x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
|
|
||||||
x_se = self.conv_se_reduce(x_se)
|
|
||||||
x_se = F.relu(x_se)
|
|
||||||
x_se = self.conv_se_expand(x_se)
|
|
||||||
x = F.sigmoid(x_se) * x
|
|
||||||
# Pointwise projection
|
|
||||||
x = self.conv_project(x)
|
|
||||||
x = self.bn2(x)
|
|
||||||
if self.has_residual:
|
|
||||||
return x + residual
|
|
||||||
else:
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class MnasNet(nn.Module):
|
|
||||||
""" MNASNet
|
|
||||||
|
|
||||||
Based on offical TF implementation
|
|
||||||
https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
|
|
||||||
|
|
||||||
Original paper: https://arxiv.org/pdf/1807.11626.pdf.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32,
|
|
||||||
depth_multiplier=1.0, depth_divisor=8, min_depth=None,
|
|
||||||
bn_momentum=_BN_MOMENTUM_DEFAULT, bn_eps=_BN_EPS_DEFAULT, drop_rate=0.,
|
|
||||||
global_pool='avg', act_fn=F.relu):
|
|
||||||
super(MnasNet, self).__init__()
|
|
||||||
self.num_classes = num_classes
|
|
||||||
self.depth_multiplier = depth_multiplier
|
|
||||||
self.bn_momentum = bn_momentum
|
|
||||||
self.bn_eps = bn_eps
|
|
||||||
self.drop_rate = drop_rate
|
|
||||||
self.act_fn = act_fn
|
|
||||||
self.num_features = 1280
|
|
||||||
|
|
||||||
self.conv_stem = nn.Conv2d(in_chans, stem_size, 3, padding=1, stride=2, bias=False)
|
|
||||||
self.bn0 = nn.BatchNorm2d(stem_size, momentum=self.bn_momentum, eps=self.bn_eps)
|
|
||||||
|
|
||||||
blocks = []
|
|
||||||
for i, a in enumerate(block_args):
|
|
||||||
print(a) #FIXME debug
|
|
||||||
a['in_chs'] = _round_channels(a['in_chs'], depth_multiplier, depth_divisor, min_depth)
|
|
||||||
a['out_chs'] = _round_channels(a['out_chs'], depth_multiplier, depth_divisor, min_depth)
|
|
||||||
blocks.append(self._make_stack(**a))
|
|
||||||
out_chs = a['out_chs']
|
|
||||||
self.blocks = nn.Sequential(*blocks)
|
|
||||||
|
|
||||||
self.conv_head = nn.Conv2d(out_chs, self.num_features, 1, padding=0, stride=1, bias=False)
|
|
||||||
self.bn1 = nn.BatchNorm2d(self.num_features, momentum=self.bn_momentum, eps=self.bn_eps)
|
|
||||||
|
|
||||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
|
||||||
self.classifier = nn.Linear(self.num_features, self.num_classes)
|
|
||||||
|
|
||||||
for m in self.modules():
|
|
||||||
_initialize_weight(m)
|
|
||||||
|
|
||||||
def _make_stack(self, in_chs, out_chs, kernel_size, stride,
|
|
||||||
exp_ratio, id_skip, se_ratio, num_repeat):
|
|
||||||
blocks = [MnasBlock(
|
|
||||||
in_chs, out_chs, kernel_size, stride, exp_ratio, id_skip, se_ratio,
|
|
||||||
bn_momentum=self.bn_momentum, bn_eps=self.bn_eps)]
|
|
||||||
for _ in range(1, num_repeat):
|
|
||||||
blocks += [MnasBlock(
|
|
||||||
out_chs, out_chs, kernel_size, 1, exp_ratio, id_skip, se_ratio,
|
|
||||||
bn_momentum=self.bn_momentum, bn_eps=self.bn_eps)]
|
|
||||||
return nn.Sequential(*blocks)
|
|
||||||
|
|
||||||
def get_classifier(self):
|
|
||||||
return self.classifier
|
|
||||||
|
|
||||||
def reset_classifier(self, num_classes, global_pool='avg'):
|
|
||||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
|
||||||
self.num_classes = num_classes
|
|
||||||
del self.classifier
|
|
||||||
if num_classes:
|
|
||||||
self.classifier = nn.Linear(
|
|
||||||
self.num_features * self.global_pool.feat_mult(), num_classes)
|
|
||||||
else:
|
|
||||||
self.classifier = None
|
|
||||||
|
|
||||||
def forward_features(self, x, pool=True):
|
|
||||||
x = self.conv_stem(x)
|
|
||||||
x = self.bn0(x)
|
|
||||||
x = self.act_fn(x)
|
|
||||||
x = self.blocks(x)
|
|
||||||
x = self.conv_head(x)
|
|
||||||
x = self.bn1(x)
|
|
||||||
x = self.act_fn(x)
|
|
||||||
if pool:
|
|
||||||
x = self.global_pool(x)
|
|
||||||
x = x.view(x.size(0), -1)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.forward_features(x)
|
|
||||||
if self.drop_rate > 0.:
|
|
||||||
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
|
||||||
return self.classifier(x)
|
|
||||||
|
|
||||||
|
|
||||||
def mnasnet_a1(depth_multiplier, num_classes=1000, **kwargs):
|
|
||||||
"""Creates a mnasnet-a1 model.
|
|
||||||
Args:
|
|
||||||
depth_multiplier: multiplier to number of channels per layer.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# defs from https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py
|
|
||||||
block_defs = [
|
|
||||||
'r1_k3_s11_e1_i32_o16_noskip',
|
|
||||||
'r2_k3_s22_e6_i16_o24',
|
|
||||||
'r3_k5_s22_e3_i24_o40_se0.25',
|
|
||||||
'r4_k3_s22_e6_i40_o80',
|
|
||||||
'r2_k3_s11_e6_i80_o112_se0.25',
|
|
||||||
'r3_k5_s22_e6_i112_o160_se0.25',
|
|
||||||
'r1_k3_s11_e6_i160_o320'
|
|
||||||
]
|
|
||||||
block_args = _decode_block_args(block_defs)
|
|
||||||
model = MnasNet(
|
|
||||||
block_args,
|
|
||||||
num_classes=num_classes,
|
|
||||||
depth_multiplier=depth_multiplier,
|
|
||||||
depth_divisor=8,
|
|
||||||
min_depth=None,
|
|
||||||
stem_size=32,
|
|
||||||
bn_momentum=_BN_MOMENTUM_DEFAULT,
|
|
||||||
bn_eps=_BN_EPS_DEFAULT,
|
|
||||||
#drop_rate=0.2,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def mnasnet_b1(depth_multiplier, num_classes=1000, **kwargs):
|
|
||||||
"""Creates a mnasnet-b1 model.
|
|
||||||
Args:
|
|
||||||
depth_multiplier: multiplier to number of channels per layer.
|
|
||||||
"""
|
|
||||||
# from https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py
|
|
||||||
blocks_defs = [
|
|
||||||
'r1_k3_s11_e1_i32_o16_noskip',
|
|
||||||
'r3_k3_s22_e3_i16_o24',
|
|
||||||
'r3_k5_s22_e3_i24_o40',
|
|
||||||
'r3_k5_s22_e6_i40_o80',
|
|
||||||
'r2_k3_s11_e6_i80_o96',
|
|
||||||
'r4_k5_s22_e6_i96_o192',
|
|
||||||
'r1_k3_s11_e6_i192_o320_noskip'
|
|
||||||
]
|
|
||||||
block_args = _decode_block_args(blocks_defs)
|
|
||||||
model = MnasNet(
|
|
||||||
block_args,
|
|
||||||
num_classes=num_classes,
|
|
||||||
depth_multiplier=depth_multiplier,
|
|
||||||
depth_divisor=8,
|
|
||||||
min_depth=None,
|
|
||||||
stem_size=32,
|
|
||||||
bn_momentum=_BN_MOMENTUM_DEFAULT,
|
|
||||||
bn_eps=_BN_EPS_DEFAULT,
|
|
||||||
#drop_rate=0.2,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def mnasnet_small(depth_multiplier, num_classes=1000, **kwargs):
|
|
||||||
"""Creates a mnasnet-b1 model.
|
|
||||||
Args:
|
|
||||||
depth_multiplier: multiplier to number of channels per layer.
|
|
||||||
"""
|
|
||||||
# from https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py
|
|
||||||
blocks_defs = [
|
|
||||||
'r1_k3_s11_e1_i16_o8',
|
|
||||||
'r1_k3_s22_e3_i8_o16',
|
|
||||||
'r2_k3_s22_e6_i16_o16',
|
|
||||||
'r4_k5_s22_e6_i16_o32_se0.25',
|
|
||||||
'r3_k3_s11_e6_i32_o32_se0.25',
|
|
||||||
'r3_k5_s22_e6_i32_o88_se0.25',
|
|
||||||
'r1_k3_s11_e6_i88_o144'
|
|
||||||
]
|
|
||||||
block_args = _decode_block_args(blocks_defs)
|
|
||||||
model = MnasNet(
|
|
||||||
block_args,
|
|
||||||
num_classes=num_classes,
|
|
||||||
depth_multiplier=depth_multiplier,
|
|
||||||
depth_divisor=8,
|
|
||||||
min_depth=None,
|
|
||||||
stem_size=8,
|
|
||||||
bn_momentum=_BN_MOMENTUM_DEFAULT,
|
|
||||||
bn_eps=_BN_EPS_DEFAULT,
|
|
||||||
#drop_rate=0.2,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def mnasnet0_50(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
|
|
||||||
""" MNASNet B1, depth multiplier of 0.5. """
|
|
||||||
default_cfg = default_cfgs['mnasnet0_50']
|
|
||||||
model = mnasnet_b1(0.5, num_classes=num_classes, in_chans=in_chans, **kwargs)
|
|
||||||
model.default_cfg = default_cfg
|
|
||||||
if pretrained:
|
|
||||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def mnasnet0_75(num_classes, in_chans=3, pretrained=False, **kwargs):
|
|
||||||
""" MNASNet B1, depth multiplier of 0.75. """
|
|
||||||
default_cfg = default_cfgs['mnasnet0_50']
|
|
||||||
model = mnasnet_b1(0.75, num_classes=num_classes, in_chans=in_chans, **kwargs)
|
|
||||||
model.default_cfg = default_cfg
|
|
||||||
if pretrained:
|
|
||||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def mnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs):
|
|
||||||
""" MNASNet B1, depth multiplier of 1.0. """
|
|
||||||
default_cfg = default_cfgs['mnasnet0_50']
|
|
||||||
model = mnasnet_b1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
|
|
||||||
model.default_cfg = default_cfg
|
|
||||||
if pretrained:
|
|
||||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def mnasnet1_40(num_classes, in_chans=3, pretrained=False, **kwargs):
|
|
||||||
""" MNASNet B1, depth multiplier of 1.4 """
|
|
||||||
default_cfg = default_cfgs['mnasnet0_50']
|
|
||||||
model = mnasnet_b1(1.4, num_classes=num_classes, in_chans=in_chans, **kwargs)
|
|
||||||
model.default_cfg = default_cfg
|
|
||||||
if pretrained:
|
|
||||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def semnasnet0_50(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
|
|
||||||
""" MNASNet A1 (w/ SE), depth multiplier of 0.5 """
|
|
||||||
default_cfg = default_cfgs['mnasnet0_50']
|
|
||||||
model = mnasnet_a1(0.5, num_classes=num_classes, in_chans=in_chans, **kwargs)
|
|
||||||
model.default_cfg = default_cfg
|
|
||||||
if pretrained:
|
|
||||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def semnasnet0_75(num_classes, in_chans=3, pretrained=False, **kwargs):
|
|
||||||
""" MNASNet A1 (w/ SE), depth multiplier of 0.75. """
|
|
||||||
default_cfg = default_cfgs['mnasnet0_50']
|
|
||||||
model = mnasnet_a1(0.75, num_classes=num_classes, in_chans=in_chans, **kwargs)
|
|
||||||
model.default_cfg = default_cfg
|
|
||||||
if pretrained:
|
|
||||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def semnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs):
|
|
||||||
""" MNASNet A1 (w/ SE), depth multiplier of 1.0. """
|
|
||||||
default_cfg = default_cfgs['mnasnet0_50']
|
|
||||||
model = mnasnet_a1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
|
|
||||||
model.default_cfg = default_cfg
|
|
||||||
if pretrained:
|
|
||||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def semnasnet1_40(num_classes, in_chans=3, pretrained=False, **kwargs):
|
|
||||||
""" MNASNet A1 (w/ SE), depth multiplier of 1.4. """
|
|
||||||
default_cfg = default_cfgs['mnasnet0_50']
|
|
||||||
model = mnasnet_a1(1.4, num_classes=num_classes, in_chans=in_chans, **kwargs)
|
|
||||||
model.default_cfg = default_cfg
|
|
||||||
if pretrained:
|
|
||||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def mnasnet_small(num_classes, in_chans=3, pretrained=False, **kwargs):
|
|
||||||
""" MNASNet Small, depth multiplier of 1.0. """
|
|
||||||
default_cfg = default_cfgs['mnasnet_small']
|
|
||||||
model = mnasnet_small(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
|
|
||||||
model.default_cfg = default_cfg
|
|
||||||
if pretrained:
|
|
||||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
||||||
return model
|
|
@ -1,3 +1,4 @@
|
|||||||
from optim.adabound import AdaBound
|
from optim.adabound import AdaBound
|
||||||
from optim.nadam import Nadam
|
from optim.nadam import Nadam
|
||||||
|
from optim.rmsprop_tf import RMSpropTF
|
||||||
from optim.optim_factory import create_optimizer
|
from optim.optim_factory import create_optimizer
|
@ -0,0 +1,105 @@
|
|||||||
|
import torch
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
|
||||||
|
|
||||||
|
class RMSpropTF(Optimizer):
|
||||||
|
"""Implements RMSprop algorithm (TensorFlow style epsilon)
|
||||||
|
|
||||||
|
NOTE: This is a direct cut-and-paste of PyTorch RMSprop with eps applied before sqrt
|
||||||
|
to closer match Tensorflow for matching hyper-params.
|
||||||
|
|
||||||
|
Proposed by G. Hinton in his
|
||||||
|
`course <http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_.
|
||||||
|
|
||||||
|
The centered version first appears in `Generating Sequences
|
||||||
|
With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
params (iterable): iterable of parameters to optimize or dicts defining
|
||||||
|
parameter groups
|
||||||
|
lr (float, optional): learning rate (default: 1e-2)
|
||||||
|
momentum (float, optional): momentum factor (default: 0)
|
||||||
|
alpha (float, optional): smoothing constant (default: 0.99)
|
||||||
|
eps (float, optional): term added to the denominator to improve
|
||||||
|
numerical stability (default: 1e-8)
|
||||||
|
centered (bool, optional) : if ``True``, compute the centered RMSProp,
|
||||||
|
the gradient is normalized by an estimation of its variance
|
||||||
|
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False):
|
||||||
|
if not 0.0 <= lr:
|
||||||
|
raise ValueError("Invalid learning rate: {}".format(lr))
|
||||||
|
if not 0.0 <= eps:
|
||||||
|
raise ValueError("Invalid epsilon value: {}".format(eps))
|
||||||
|
if not 0.0 <= momentum:
|
||||||
|
raise ValueError("Invalid momentum value: {}".format(momentum))
|
||||||
|
if not 0.0 <= weight_decay:
|
||||||
|
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
|
||||||
|
if not 0.0 <= alpha:
|
||||||
|
raise ValueError("Invalid alpha value: {}".format(alpha))
|
||||||
|
|
||||||
|
defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay)
|
||||||
|
super(RMSpropTF, self).__init__(params, defaults)
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
super(RMSpropTF, self).__setstate__(state)
|
||||||
|
for group in self.param_groups:
|
||||||
|
group.setdefault('momentum', 0)
|
||||||
|
group.setdefault('centered', False)
|
||||||
|
|
||||||
|
def step(self, closure=None):
|
||||||
|
"""Performs a single optimization step.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
closure (callable, optional): A closure that reevaluates the model
|
||||||
|
and returns the loss.
|
||||||
|
"""
|
||||||
|
loss = None
|
||||||
|
if closure is not None:
|
||||||
|
loss = closure()
|
||||||
|
|
||||||
|
for group in self.param_groups:
|
||||||
|
for p in group['params']:
|
||||||
|
if p.grad is None:
|
||||||
|
continue
|
||||||
|
grad = p.grad.data
|
||||||
|
if grad.is_sparse:
|
||||||
|
raise RuntimeError('RMSprop does not support sparse gradients')
|
||||||
|
state = self.state[p]
|
||||||
|
|
||||||
|
# State initialization
|
||||||
|
if len(state) == 0:
|
||||||
|
state['step'] = 0
|
||||||
|
state['square_avg'] = torch.zeros_like(p.data)
|
||||||
|
if group['momentum'] > 0:
|
||||||
|
state['momentum_buffer'] = torch.zeros_like(p.data)
|
||||||
|
if group['centered']:
|
||||||
|
state['grad_avg'] = torch.zeros_like(p.data)
|
||||||
|
|
||||||
|
square_avg = state['square_avg']
|
||||||
|
alpha = group['alpha']
|
||||||
|
|
||||||
|
state['step'] += 1
|
||||||
|
|
||||||
|
if group['weight_decay'] != 0:
|
||||||
|
grad = grad.add(group['weight_decay'], p.data)
|
||||||
|
|
||||||
|
square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad)
|
||||||
|
|
||||||
|
if group['centered']:
|
||||||
|
grad_avg = state['grad_avg']
|
||||||
|
grad_avg.mul_(alpha).add_(1 - alpha, grad)
|
||||||
|
avg = square_avg.addcmul(-1, grad_avg, grad_avg).add(group['eps']).sqrt_()
|
||||||
|
else:
|
||||||
|
avg = square_avg.add(group['eps']).sqrt_()
|
||||||
|
|
||||||
|
if group['momentum'] > 0:
|
||||||
|
buf = state['momentum_buffer']
|
||||||
|
buf.mul_(group['momentum']).addcdiv_(grad, avg)
|
||||||
|
p.data.add_(-group['lr'], buf)
|
||||||
|
else:
|
||||||
|
p.data.addcdiv_(-group['lr'], grad, avg)
|
||||||
|
|
||||||
|
return loss
|
Loading…
Reference in new issue