diff --git a/models/genmobilenet.py b/models/genmobilenet.py new file mode 100644 index 00000000..756189b1 --- /dev/null +++ b/models/genmobilenet.py @@ -0,0 +1,957 @@ +""" Generic MobileNet + +A generic MobileNet class with building blocks to support a variety of models: +* MNasNet B1, A1 (SE), Small +* MobileNetV2 +* FBNet-C (TODO A & B) +* ChamNet (TODO still guessing at architecture definition) +* ShuffleNetV2 (TODO add IR shuffle block) +* And likely more... + +TODO not all combinations and variations have been tested. Currently working on training hyper-params... + +Hacked together by Ross Wightman +""" + +import math +import re +from copy import deepcopy + +import torch +import torch.nn as nn +import torch.nn.functional as F +from models.helpers import load_pretrained +from models.adaptive_avgmax_pool import SelectAdaptivePool2d +from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD + +__all__ = ['GenMobileNet', 'mnasnet0_50', 'mnasnet0_75', 'mnasnet1_00', 'mnasnet1_40', + 'semnasnet0_50', 'semnasnet0_75', 'semnasnet1_00', 'semnasnet1_40', + 'mnasnet_small'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv_stem', 'classifier': 'classifier', + **kwargs + } + + +default_cfgs = { + 'mnasnet0_50': _cfg(url=''), + 'mnasnet0_75': _cfg(url=''), + 'mnasnet1_00': _cfg(url=''), + 'mnasnet1_40': _cfg(url=''), + 'semnasnet0_50': _cfg(url=''), + 'semnasnet0_75': _cfg(url=''), + 'semnasnet1_00': _cfg(url=''), + 'semnasnet1_40': _cfg(url=''), + 'mnasnet_small': _cfg(url=''), + 'mobilenetv1_1_00': _cfg(url=''), + 'mobilenetv2_1_00': _cfg(url=''), + 'chamnetv1_1_00': _cfg(url=''), + 'chamnetv2_1_00': _cfg(url=''), + 'fbnetc_1_00': _cfg(url=''), +} + +_DEBUG = True + +# default args for PyTorch BN impl +_BN_MOMENTUM_PT_DEFAULT = 0.1 +_BN_EPS_PT_DEFAULT = 1e-5 + +# defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per +# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay) +_BN_MOMENTUM_TF_DEFAULT = 1 - 0.99 # NOTE this varies, .99 or .9997 depending on ref +_BN_EPS_TF_DEFAULT = 1e-3 + + +def _resolve_bn_params(kwargs): + # NOTE kwargs passed as dict intentionally + bn_momentum_default = _BN_MOMENTUM_PT_DEFAULT + bn_eps_default = _BN_EPS_PT_DEFAULT + bn_tf = kwargs.pop('bn_tf', False) + if bn_tf: + bn_momentum_default = _BN_MOMENTUM_TF_DEFAULT + bn_eps_default = _BN_EPS_TF_DEFAULT + bn_momentum = kwargs.pop('bn_momentum', None) + bn_eps = kwargs.pop('bn_eps', None) + if bn_momentum is None: + bn_momentum = bn_momentum_default + if bn_eps is None: + bn_eps = bn_eps_default + return bn_momentum, bn_eps + + +def _round_channels(channels, depth_multiplier=1.0, depth_divisor=8, min_depth=None): + """Round number of filters based on depth multiplier.""" + if not depth_multiplier: + return channels + + channels *= depth_multiplier + min_depth = min_depth or depth_divisor + new_channels = max( + int(channels + depth_divisor / 2) // depth_divisor * depth_divisor, + min_depth) + # Make sure that round down does not go down by more than 10%. + if new_channels < 0.9 * channels: + new_channels += depth_divisor + return new_channels + + +def _decode_block_str(block_str): + """ Decode block definition string + + Gets a list of block arg (dicts) through a string notation of arguments. + E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip + + All args can exist in any order with the exception of the leading string which + is assumed to indicate the block type. + + leading string - block type ( + ir = InvertedResidual, ds = DepthwiseSep, dsa = DeptwhiseSep with pw act, + ca = Cascade3x3, and possibly more) + r - number of repeat blocks, + k - kernel size, + s - strides (1-9), + e - expansion ratio, + c - output channels, + se - squeeze/excitation ratio + Args: + block_str: a string representation of block arguments. + Returns: + A list of block args (dicts) + Raises: + ValueError: if the string def not properly specified (TODO) + """ + assert isinstance(block_str, str) + ops = block_str.split('_') + block_type = ops[0] # take the block type off the front + ops = ops[1:] + options = {} + for op in ops: + splits = re.split(r'(\d.*)', op) + if len(splits) >= 2: + key, value = splits[:2] + options[key] = value + + # FIXME validate args and throw + + num_repeat = int(options['r']) + # each type of block has different valid arguments, fill accordingly + if block_type == 'ir': + block_args = dict( + block_type=block_type, + kernel_size=int(options['k']), + out_chs=int(options['c']), + exp_ratio=int(options['e']), + se_ratio=float(options['se']) if 'se' in options else None, + stride=int(options['s']), + noskip=('noskip' in block_str), + ) + if 'g' in options: + block_args['pw_group'] = options['g'] + if options['g'] > 1: + block_args['shuffle_type'] = 'mid' + elif block_type == 'ca': + block_args = dict( + block_type=block_type, + out_chs=int(options['c']), + stride=int(options['s']), + noskip=('noskip' in block_str), + ) + elif block_type == 'ds' or block_type == 'dsa': + block_args = dict( + block_type=block_type, + kernel_size=int(options['k']), + out_chs=int(options['c']), + stride=int(options['s']), + noskip=block_type == 'dsa' or 'noskip' in block_str, + pw_act=block_type == 'dsa', + ) + else: + assert False, 'Unknown block type (%s)' % block_type + + # return a list of block args expanded by num_repeat + return [deepcopy(block_args) for _ in range(num_repeat)] + + +def _get_padding(kernel_size, stride, dilation): + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + return padding + + +def _decode_arch_args(string_list): + block_args = [] + for block_str in string_list: + block_args.append(_decode_block_str(block_str)) + return block_args + + +def _decode_arch_def(arch_def): + arch_args = [] + for stack_idx, block_strings in enumerate(arch_def): + assert isinstance(block_strings, list) + stack_args = [] + for block_str in block_strings: + assert isinstance(block_str, str) + stack_args.extend(_decode_block_str(block_str)) + arch_args.append(stack_args) + return arch_args + + +class _BlockBuilder: + """ Build Trunk Blocks + + This ended up being somewhat of a cross between + https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py + and + https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py + + """ + + def __init__(self, depth_multiplier=1.0, depth_divisor=8, min_depth=None, + bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT): + self.depth_multiplier = depth_multiplier + self.depth_divisor = depth_divisor + self.min_depth = min_depth + self.bn_momentum = bn_momentum + self.bn_eps = bn_eps + self.in_chs = None + + def _round_channels(self, chs): + return _round_channels(chs, self.depth_multiplier, self.depth_divisor, self.min_depth) + + def _make_block(self, ba): + bt = ba.pop('block_type') + ba['in_chs'] = self.in_chs + ba['out_chs'] = _round_channels(ba['out_chs']) + ba['bn_momentum'] = self.bn_momentum + ba['bn_eps'] = self.bn_eps + if _DEBUG: + print('args:', ba) + # could replace this with lambdas or functools binding if variety increases + if bt == 'ir': + block = InvertedResidual(**ba) + elif bt == 'ds' or bt == 'dsa': + block = DepthwiseSeparableConv(**ba) + elif bt == 'ca': + block = CascadeConv3x3(**ba) + else: + assert False, 'Uknkown block type (%s) while building model.' % bt + self.in_chs = ba['out_chs'] # update in_chs for arg of next block + return block + + def _make_stack(self, stack_args): + blocks = [] + # each stack (stage) contains a list of block arguments + for block_idx, ba in enumerate(stack_args): + if _DEBUG: + print('block', block_idx, end=', ') + if block_idx >= 1: + # only the first block in any stack/stage can have a stride > 1 + ba['stride'] = 1 + block = self._make_block(ba) + blocks.append(block) + return nn.Sequential(*blocks) + + def __call__(self, in_chs, arch_def): + """ Build the blocks + Args: + in_chs: Number of input-channels passed to first block + arch_def: A list of lists, outer list defines stacks (or stages), inner + list contains strings defining block configuration(s) + Return: + List of block stacks (each stack wrapped in nn.Sequential) + """ + arch_args = _decode_arch_def(arch_def) # convert and expand string defs to arg dicts + if _DEBUG: + print('Building model trunk with %d stacks (stages)...' % len(arch_args)) + self.in_chs = in_chs + blocks = [] + # outer list of arch_args defines the stacks ('stages' by some conventions) + for stack_idx, stack in enumerate(arch_args): + if _DEBUG: + print('stack', stack_idx) + assert isinstance(stack, list) + stack = self._make_stack(stack) + blocks.append(stack) + if _DEBUG: + print() + return blocks + + +def _initialize_weight(m): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2.0 / n)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + n = m.weight.size(1) + init_range = 1.0 / math.sqrt(n) + m.weight.data.uniform_(-init_range, init_range) + m.bias.data.zero_() + + +class DepthwiseSeparableConv(nn.Module): + def __init__(self, in_chs, out_chs, kernel_size, + stride=1, act_fn=F.relu, noskip=False, pw_act=False, + bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT): + super(DepthwiseSeparableConv, self).__init__() + assert stride in [1, 2] + self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip + self.has_pw_act = pw_act # activation after point-wise conv + self.act_fn = act_fn + + self.conv_dw = nn.Conv2d( + in_chs, in_chs, kernel_size, + stride=stride, padding=kernel_size // 2, groups=in_chs, bias=False) + self.bn1 = nn.BatchNorm2d(in_chs, momentum=bn_momentum, eps=bn_eps) + self.conv_pw = nn.Conv2d(in_chs, out_chs, 1, bias=False) + self.bn2 = nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps) + + def forward(self, x): + residual = x + x = self.conv_dw(x) + x = self.bn1(x) + x = self.act_fn(x) + x = self.conv_pw(x) + x = self.bn2(x) + if self.has_pw_act: + x = self.act_fn(x) + if self.has_residual: + x += residual + return x + + +class CascadeConv3x3(nn.Sequential): + # FIXME lifted from maskrcnn_benchmark blocks, haven't used yet + def __init__(self, in_chs, out_chs, stride, act_fn=F.relu, noskip=False, + bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT): + super(CascadeConv3x3, self).__init__() + assert stride in [1, 2] + self.has_residual = not noskip and (stride == 1 and in_chs == out_chs) + self.act_fn = act_fn + + self.conv1 = nn.Conv2d(in_chs, in_chs, 3, stride=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(in_chs, momentum=bn_momentum, eps=bn_eps) + self.conv2 = nn.Conv2d(in_chs, out_chs, 3, stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps) + + def forward(self, x): + residual = x + x = self.conv1(x) + x = self.bn1(x) + x = self.act_fn(x) + x = self.conv2(x) + x = self.bn2(x) + if self.has_residual: + x += residual + return x + + +class ChannelShuffle(nn.Module): + # FIXME lifted from maskrcnn_benchmark blocks, haven't used yet + def __init__(self, groups): + super(ChannelShuffle, self).__init__() + self.groups = groups + + def forward(self, x): + """Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]""" + N, C, H, W = x.size() + g = self.groups + assert C % g == 0, "Incompatible group size {} for input channel {}".format( + g, C + ) + return ( + x.view(N, g, int(C / g), H, W) + .permute(0, 2, 1, 3, 4) + .contiguous() + .view(N, C, H, W) + ) + + +class SqueezeExcite(nn.Module): + def __init__(self, in_chs, se_ratio=0.25, act_fn=F.relu): + super(SqueezeExcite, self).__init__() + self.act_fn = act_fn + reduced_chs = max(1, int(in_chs * se_ratio)) + self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True) + self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True) + + def forward(self, x): + # NOTE adaptiveavgpool can be used here, but seems to cause issues with NVIDIA AMP performance + x_se = x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1) + x_se = self.conv_reduce(x_se) + x_se = self.act_fn(x_se) + x_se = self.conv_expand(x_se) + x = torch.sigmoid(x_se) * x + return x + + +class InvertedResidual(nn.Module): + """ Inverted residual block w/ optional SE""" + + def __init__(self, in_chs, out_chs, kernel_size, + stride=1, act_fn=F.relu, exp_ratio=1.0, noskip=False, + se_ratio=0., shuffle_type=None, pw_group=1, + bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT): + super(InvertedResidual, self).__init__() + mid_chs = int(in_chs * exp_ratio) + self.has_se = se_ratio is not None and se_ratio > 0. + self.has_residual = (in_chs == out_chs and stride == 1) and not noskip + self.act_fn = act_fn + + # Point-wise expansion + self.conv_pw = nn.Conv2d(in_chs, mid_chs, 1, groups=pw_group, bias=False) + self.bn1 = nn.BatchNorm2d(mid_chs, momentum=bn_momentum, eps=bn_eps) + + self.shuffle_type = shuffle_type + if shuffle_type is not None: + self.shuffle = ChannelShuffle(pw_group) + + # Depth-wise convolution + self.conv_dw = nn.Conv2d( + mid_chs, mid_chs, kernel_size, padding=kernel_size // 2, + stride=stride, groups=mid_chs, bias=False) + self.bn2 = nn.BatchNorm2d(mid_chs, momentum=bn_momentum, eps=bn_eps) + + # Squeeze-and-excitation + if self.has_se: + self.se = SqueezeExcite(mid_chs, se_ratio) + + # Point-wise linear projection + self.conv_pwl = nn.Conv2d(mid_chs, out_chs, 1, groups=pw_group, bias=False) + self.bn3 = nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps) + + def forward(self, x): + residual = x + + # Point-wise expansion + x = self.conv_pw(x) + x = self.bn1(x) + x = self.act_fn(x) + + # FIXME haven't tried this yet + # for channel shuffle when using groups with pointwise convs as per FBNet variants + if self.shuffle_type == "mid": + x = self.shuffle(x) + + # Depth-wise convolution + x = self.conv_dw(x) + x = self.bn2(x) + x = self.act_fn(x) + + # Squeeze-and-excitation + if self.has_se: + x = self.se(x) + + # Point-wise linear projection + x = self.conv_pwl(x) + x = self.bn3(x) + + if self.has_residual: + x += residual + + # NOTE maskrcnn_benchmark building blocks have an SE module defined here for some variants + + return x + + +class GenMobileNet(nn.Module): + """ Generic Mobile Net + + An implementation of mobile optimized networks that covers: + * MobileNetV1 + * MobileNetV2 + * MNASNet A1, B1, and small + * FBNet A, B, and C + * ChamNet (arch details are murky) + """ + + def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280, + depth_multiplier=1.0, depth_divisor=8, min_depth=None, + bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT, + drop_rate=0., act_fn=F.relu, global_pool='avg', skip_head_conv=False): + super(GenMobileNet, self).__init__() + self.num_classes = num_classes + self.depth_multiplier = depth_multiplier + self.drop_rate = drop_rate + self.act_fn = act_fn + self.num_features = num_features + + stem_size = _round_channels(stem_size, depth_multiplier, depth_divisor, min_depth) + self.conv_stem = nn.Conv2d(in_chans, stem_size, 3, padding=1, stride=2, bias=False) + self.bn1 = nn.BatchNorm2d(stem_size, momentum=bn_momentum, eps=bn_eps) + in_chs = stem_size + + builder = _BlockBuilder( + depth_multiplier, depth_divisor, min_depth, + bn_momentum, bn_eps) + self.blocks = nn.Sequential(*builder(in_chs, block_args)) + in_chs = builder.in_chs + + if skip_head_conv: + self.conv_head = None + assert in_chs == self.num_features + else: + self.conv_head = nn.Conv2d(in_chs, self.num_features, 1, padding=0, stride=1, bias=False) + self.bn2 = nn.BatchNorm2d(self.num_features, momentum=bn_momentum, eps=bn_eps) + + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.classifier = nn.Linear(self.num_features, self.num_classes) + + for m in self.modules(): + _initialize_weight(m) + + def get_classifier(self): + return self.classifier + + def reset_classifier(self, num_classes, global_pool='avg'): + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.num_classes = num_classes + del self.classifier + if num_classes: + self.classifier = nn.Linear( + self.num_features * self.global_pool.feat_mult(), num_classes) + else: + self.classifier = None + + def forward_features(self, x, pool=True): + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act_fn(x) + x = self.blocks(x) + if self.conv_head is not None: + x = self.conv_head(x) + x = self.bn2(x) + x = self.act_fn(x) + if pool: + x = self.global_pool(x) + x = x.view(x.size(0), -1) + return x + + def forward(self, x): + x = self.forward_features(x) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + return self.classifier(x) + + +def _gen_mnasnet_a1(depth_multiplier, num_classes=1000, **kwargs): + """Creates a mnasnet-a1 model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet + Paper: https://arxiv.org/pdf/1807.11626.pdf. + + Args: + depth_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16_noskip'], + # stage 1, 112x112 in + ['ir_r2_k3_s2_e6_c24'], + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40_se0.25'], + # stage 3, 28x28 in + ['ir_r4_k3_s2_e6_c80'], + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112_se0.25'], + # stage 5, 14x14in + ['ir_r3_k5_s2_e6_c160_se0.25'], + # stage 6, 7x7 in + ['ir_r1_k3_s1_e6_c320'], + ] + bn_momentum, bn_eps = _resolve_bn_params(kwargs) + model = GenMobileNet( + arch_def, + num_classes=num_classes, + stem_size=32, + depth_multiplier=depth_multiplier, + depth_divisor=8, + min_depth=None, + bn_momentum=bn_momentum, + bn_eps=bn_eps, + **kwargs + ) + return model + + +def _gen_mnasnet_b1(depth_multiplier, num_classes=1000, **kwargs): + """Creates a mnasnet-b1 model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet + Paper: https://arxiv.org/pdf/1807.11626.pdf. + + Args: + depth_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_c16_noskip'], + # stage 1, 112x112 in + ['ir_r3_k3_s2_e3_c24'], + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40'], + # stage 3, 28x28 in + ['ir_r3_k5_s2_e6_c80'], + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c96'], + # stage 5, 14x14in + ['ir_r4_k5_s2_e6_c192'], + # stage 6, 7x7 in + ['ir_r1_k3_s1_e6_c320_noskip'] + ] + bn_momentum, bn_eps = _resolve_bn_params(kwargs) + model = GenMobileNet( + arch_def, + num_classes=num_classes, + stem_size=32, + depth_multiplier=depth_multiplier, + depth_divisor=8, + min_depth=None, + bn_momentum=bn_momentum, + bn_eps=bn_eps, + **kwargs + ) + return model + + +def _gen_mnasnet_small(depth_multiplier, num_classes=1000, **kwargs): + """Creates a mnasnet-b1 model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet + Paper: https://arxiv.org/pdf/1807.11626.pdf. + + Args: + depth_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + ['ds_r1_k3_s1_c8'], + ['ir_r1_k3_s2_e3_c16'], + ['ir_r2_k3_s2_e6_c16'], + ['ir_r4_k5_s2_e6_c32_se0.25'], + ['ir_r3_k3_s1_e6_c32_se0.25'], + ['ir_r3_k5_s2_e6_c88_se0.25'], + ['ir_r1_k3_s1_e6_c144'] + ] + bn_momentum, bn_eps = _resolve_bn_params(kwargs) + model = GenMobileNet( + arch_def, + num_classes=num_classes, + stem_size=8, + depth_multiplier=depth_multiplier, + depth_divisor=8, + min_depth=None, + bn_momentum=bn_momentum, + bn_eps=bn_eps, + **kwargs + ) + return model + + +def _gen_mobilenet_v1(depth_multiplier, num_classes=1000, **kwargs): + """ + Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py + Paper: https://arxiv.org/abs/1801.04381 + """ + arch_def = [ + ['dsa_r1_k3_s1_c64'], + ['dsa_r2_k3_s2_c128'], + ['dsa_r2_k3_s2_c256'], + ['dsa_r6_k3_s2_c512'], + ['dsa_r2_k3_s2_c1024'], + ] + bn_momentum, bn_eps = _resolve_bn_params(kwargs) + model = GenMobileNet( + arch_def, + num_classes=num_classes, + stem_size=32, + num_features=1024, + depth_multiplier=depth_multiplier, + depth_divisor=8, + min_depth=None, + bn_momentum=bn_momentum, + bn_eps=bn_eps, + act_fn=F.relu6, + skip_head_conv=True, + **kwargs + ) + return model + + +def _gen_mobilenet_v2(depth_multiplier, num_classes=1000, **kwargs): + """ Generate MobileNet-V2 network + Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py + Paper: https://arxiv.org/abs/1801.04381 + """ + arch_def = [ + ['ds_r1_k3_s1_c16'], + ['ir_r2_k3_s2_e6_c24'], + ['ir_r3_k3_s2_e6_c32'], + ['ir_r4_k3_s2_e6_c64'], + ['ir_r3_k3_s1_e6_c96'], + ['ir_r3_k3_s2_e6_c160'], + ['ir_r1_k3_s1_e6_c320'], + ] + bn_momentum, bn_eps = _resolve_bn_params(kwargs) + model = GenMobileNet( + arch_def, + num_classes=num_classes, + stem_size=32, + depth_multiplier=depth_multiplier, + depth_divisor=8, + min_depth=None, + bn_momentum=bn_momentum, + bn_eps=bn_eps, + act_fn=F.relu6, + **kwargs + ) + return model + + +def _gen_chamnet_v1(depth_multiplier, num_classes=1000, **kwargs): + """ Generate Chameleon Network (ChamNet) + + Paper: https://arxiv.org/abs/1812.08934 + Ref Impl: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py + + FIXME: this a bit of an educated guess based on trunkd def in maskrcnn_benchmark + """ + arch_def = [ + ['ir_r1_k3_s1_e1_c24'], + ['ir_r2_k7_s2_e4_c48'], + ['ir_r5_k3_s2_e7_c64'], + ['ir_r7_k5_s2_e12_c56'], + ['ir_r5_k3_s1_e8_c88'], + ['ir_r4_k3_s2_e7_c152'], + ['ir_r1_k3_s1_e10_c104'], + ] + bn_momentum, bn_eps = _resolve_bn_params(kwargs) + model = GenMobileNet( + arch_def, + num_classes=num_classes, + stem_size=32, + num_features=1280, # no idea what this is? try mobile/mnasnet default? + depth_multiplier=depth_multiplier, + depth_divisor=8, + min_depth=None, + bn_momentum=bn_momentum, + bn_eps=bn_eps, + **kwargs + ) + return model + + +def _gen_chamnet_v2(depth_multiplier, num_classes=1000, **kwargs): + """ Generate Chameleon Network (ChamNet) + + Paper: https://arxiv.org/abs/1812.08934 + Ref Impl: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py + + FIXME: this a bit of an educated guess based on trunk def in maskrcnn_benchmark + """ + arch_def = [ + ['ir_r1_k3_s1_e1_c24'], + ['ir_r4_k5_s2_e8_c32'], + ['ir_r6_k7_s2_e5_c48'], + ['ir_r3_k5_s2_e9_c56'], + ['ir_r6_k3_s1_e6_c56'], + ['ir_r6_k3_s2_e2_c152'], + ['ir_r1_k3_s1_e6_c112'], + ] + bn_momentum, bn_eps = _resolve_bn_params(kwargs) + model = GenMobileNet( + arch_def, + num_classes=num_classes, + stem_size=32, + num_features=1280, # no idea what this is? try mobile/mnasnet default? + depth_multiplier=depth_multiplier, + depth_divisor=8, + min_depth=None, + bn_momentum=bn_momentum, + bn_eps=bn_eps, + **kwargs + ) + return model + + +def _gen_fbnetc(depth_multiplier, num_classes=1000, **kwargs): + """ FBNet-C + + Paper: https://arxiv.org/abs/1812.03443 + Ref Impl: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py + + NOTE: the impl above does not relate to the 'C' variant here, that was derived from paper, + it was used to confirm some building block details + """ + arch_def = [ + ['ir_r1_k3_s1_e1_c16'], + ['ir_r1_k3_s2_e6_c24', 'ir_r2_k3_s1_e1_c24'], + ['ir_r1_k5_s2_e6_c32', 'ir_r1_k5_s1_e3_c32', 'ir_r1_k5_s1_e6_c32', 'ir_r1_k3_s1_e6_c32'], + ['ir_r1_k5_s2_e6_c64', 'ir_r1_k5_s1_e3_c64', 'ir_r2_k5_s1_e6_c64'], + ['ir_r3_k5_s1_e6_c112', 'ir_r1_k5_s1_e3_c112'], + ['ir_r4_k5_s2_e6_c184'], + ['ir_r1_k3_s1_e6_c352'], + ] + bn_momentum, bn_eps = _resolve_bn_params(kwargs) + model = GenMobileNet( + arch_def, + num_classes=num_classes, + stem_size=16, + num_features=1984, # paper suggests this, but is not 100% clear + depth_multiplier=depth_multiplier, + depth_divisor=8, + min_depth=None, + bn_momentum=bn_momentum, + bn_eps=bn_eps, + **kwargs + ) + return model + + +def mnasnet0_50(num_classes=1000, in_chans=3, pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 0.5. """ + default_cfg = default_cfgs['mnasnet0_50'] + model = _gen_mnasnet_b1(0.5, num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +def mnasnet0_75(num_classes, in_chans=3, pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 0.75. """ + default_cfg = default_cfgs['mnasnet0_75'] + model = _gen_mnasnet_b1(0.75, num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +def mnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 1.0. """ + default_cfg = default_cfgs['mnasnet1_00'] + model = _gen_mnasnet_b1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +def mnasnet1_40(num_classes, in_chans=3, pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 1.4 """ + default_cfg = default_cfgs['mnasnet1_40'] + model = _gen_mnasnet_b1(1.4, num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +def semnasnet0_50(num_classes=1000, in_chans=3, pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 0.5 """ + default_cfg = default_cfgs['semnasnet0_50'] + model = _gen_mnasnet_a1(0.5, num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +def semnasnet0_75(num_classes, in_chans=3, pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 0.75. """ + default_cfg = default_cfgs['semnasnet0_75'] + model = _gen_mnasnet_a1(0.75, num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +def semnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """ + default_cfg = default_cfgs['semnasnet1_00'] + model = _gen_mnasnet_a1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +def semnasnet1_40(num_classes, in_chans=3, pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 1.4. """ + default_cfg = default_cfgs['semnasnet1_40'] + model = _gen_mnasnet_a1(1.4, num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +def mnasnet_small(num_classes, in_chans=3, pretrained=False, **kwargs): + """ MNASNet Small, depth multiplier of 1.0. """ + default_cfg = default_cfgs['mnasnet_small'] + model = _gen_mnasnet_small(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +def mobilenetv1_1_00(num_classes, in_chans=3, pretrained=False, **kwargs): + """ MobileNet V1 """ + default_cfg = default_cfgs['mobilenetv1_1_00'] + model = _gen_mobilenet_v1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +def mobilenetv2_1_00(num_classes, in_chans=3, pretrained=False, **kwargs): + """ MobileNet V2 """ + default_cfg = default_cfgs['mobilenetv2_1_00'] + model = _gen_mobilenet_v2(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +def fbnetc_1_00(num_classes, in_chans=3, pretrained=False, **kwargs): + """ FBNet-C """ + default_cfg = default_cfgs['fbnetc_1_00'] + model = _gen_fbnetc(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +def chamnetv1_1_00(num_classes, in_chans=3, pretrained=False, **kwargs): + """ ChamNet """ + default_cfg = default_cfgs['chamnetv1_1_00'] + model = _gen_chamnet_v1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +def chamnetv2_1_00(num_classes, in_chans=3, pretrained=False, **kwargs): + """ ChamNet """ + default_cfg = default_cfgs['chamnetv2_1_00'] + model = _gen_chamnet_v2(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model diff --git a/models/mnasnet.py b/models/mnasnet.py deleted file mode 100644 index 133e654e..00000000 --- a/models/mnasnet.py +++ /dev/null @@ -1,462 +0,0 @@ -""" MNASNet (a1, b1, and small) - -Based on offical TF implementation w/ round_channels, -decode_block_str, and model block args directly transferred -https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet - -Original paper: https://arxiv.org/pdf/1807.11626.pdf. - -Hacked together by Ross Wightman -""" - -import math -import re - -import torch -import torch.nn as nn -import torch.nn.functional as F -from models.helpers import load_pretrained -from models.adaptive_avgmax_pool import SelectAdaptivePool2d -from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD - -__all__ = ['MnasNet', 'mnasnet0_50', 'mnasnet0_75', 'mnasnet1_00', 'mnasnet1_40', - 'semnasnet0_50', 'semnasnet0_75', 'semnasnet1_00', 'semnasnet1_40', - 'mnasnet_small'] - - -def _cfg(url='', **kwargs): - return { - 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), - 'crop_pct': 0.875, 'interpolation': 'bilinear', - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'conv_stem', 'classifier': 'classifier', - **kwargs - } - -default_cfgs = { - 'mnasnet0_50': _cfg(url=''), - 'mnasnet0_75': _cfg(url=''), - 'mnasnet1_00': _cfg(url=''), - 'mnasnet1_40': _cfg(url=''), - 'semnasnet0_50': _cfg(url=''), - 'semnasnet0_75': _cfg(url=''), - 'semnasnet1_00': _cfg(url=''), - 'semnasnet1_40': _cfg(url=''), - 'mnasnet_small': _cfg(url=''), -} - -_BN_MOMENTUM_DEFAULT = 1 - 0.99 -_BN_EPS_DEFAULT = 1e-3 - - -def _round_channels(channels, depth_multiplier=1.0, depth_divisor=8, min_depth=None): - """Round number of filters based on depth multiplier.""" - multiplier = depth_multiplier - divisor = depth_divisor - min_depth = min_depth - if not multiplier: - return channels - - channels *= multiplier - min_depth = min_depth or divisor - new_channels = max(min_depth, int(channels + divisor / 2) // divisor * divisor) - # Make sure that round down does not go down by more than 10%. - if new_channels < 0.9 * channels: - new_channels += divisor - return new_channels - - -def _decode_block_str(block_str): - """Gets a MNasNet block through a string notation of arguments. - E.g. r2_k3_s2_e1_i32_o16_se0.25_noskip: - r - number of repeat blocks, - k - kernel size, - s - strides (1-9), - e - expansion ratio, - i - input filters, - o - output filters, - se - squeeze/excitation ratio - Args: - block_string: a string, a string representation of block arguments. - Returns: - A BlockArgs instance. - Raises: - ValueError: if the strides option is not correctly specified. - """ - assert isinstance(block_str, str) - ops = block_str.split('_') - options = {} - for op in ops: - splits = re.split(r'(\d.*)', op) - if len(splits) >= 2: - key, value = splits[:2] - options[key] = value - - if 's' not in options or len(options['s']) != 2: - raise ValueError('Strides options should be a pair of integers.') - - return dict( - kernel_size=int(options['k']), - num_repeat=int(options['r']), - in_chs=int(options['i']), - out_chs=int(options['o']), - exp_ratio=int(options['e']), - id_skip=('noskip' not in block_str), - se_ratio=float(options['se']) if 'se' in options else None, - stride=int(options['s'][0]) # TF impl passes a list of two strides - ) - - -def _decode_block_args(string_list): - block_args = [] - for block_str in string_list: - block_args.append(_decode_block_str(block_str)) - return block_args - - -def _initialize_weight(m): - if isinstance(m, nn.Conv2d): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.data.normal_(0, math.sqrt(2.0 / n)) - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1.0) - m.bias.data.zero_() - elif isinstance(m, nn.Linear): - n = m.weight.size(1) - init_range = 1.0 / math.sqrt(n) - m.weight.data.uniform_(-init_range, init_range) - m.bias.data.zero_() - - -class MnasBlock(nn.Module): - """ MNASNet Inverted residual block w/ optional SE""" - - def __init__(self, in_chs, out_chs, kernel_size, stride, - exp_ratio=1.0, id_skip=True, se_ratio=0., - bn_momentum=0.1, bn_eps=1e-3, act_fn=F.relu): - super(MnasBlock, self).__init__() - exp_chs = int(in_chs * exp_ratio) - self.has_expansion = exp_ratio != 1 - self.has_se = se_ratio is not None and se_ratio > 0. - self.has_residual = id_skip and (in_chs == out_chs and stride == 1) - self.act_fn = act_fn - - # Pointwise expansion - if self.has_expansion: - self.conv_expand = nn.Conv2d(in_chs, exp_chs, 1, bias=False) - self.bn0 = nn.BatchNorm2d(exp_chs, momentum=bn_momentum, eps=bn_eps) - - # Depthwise convolution - self.conv_depthwise = nn.Conv2d( - exp_chs, exp_chs, kernel_size, padding=kernel_size // 2, - stride=stride, groups=exp_chs, bias=False) - self.bn1 = nn.BatchNorm2d(exp_chs, momentum=bn_momentum, eps=bn_eps) - - # Squeeze-and-excitation - if self.has_se: - num_reduced_ch = max(1, int(in_chs * se_ratio)) - self.conv_se_reduce = nn.Conv2d(exp_chs, num_reduced_ch, 1, bias=True) - self.conv_se_expand = nn.Conv2d(num_reduced_ch, exp_chs, 1, bias=True) - - # Pointwise projection - self.conv_project = nn.Conv2d(exp_chs, out_chs, 1, bias=False) - self.bn2 = nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps) - - def forward(self, x): - residual = x - # Pointwise expansion - if self.has_expansion: - x = self.conv_expand(x) - x = self.bn0(x) - x = self.act_fn(x) - # Depthwise convolution - x = self.conv_depthwise(x) - x = self.bn1(x) - x = self.act_fn(x) - # Squeeze-and-excitation - if self.has_se: - x_se = x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1) - x_se = self.conv_se_reduce(x_se) - x_se = F.relu(x_se) - x_se = self.conv_se_expand(x_se) - x = F.sigmoid(x_se) * x - # Pointwise projection - x = self.conv_project(x) - x = self.bn2(x) - if self.has_residual: - return x + residual - else: - return x - - -class MnasNet(nn.Module): - """ MNASNet - - Based on offical TF implementation - https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet - - Original paper: https://arxiv.org/pdf/1807.11626.pdf. - """ - - def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, - depth_multiplier=1.0, depth_divisor=8, min_depth=None, - bn_momentum=_BN_MOMENTUM_DEFAULT, bn_eps=_BN_EPS_DEFAULT, drop_rate=0., - global_pool='avg', act_fn=F.relu): - super(MnasNet, self).__init__() - self.num_classes = num_classes - self.depth_multiplier = depth_multiplier - self.bn_momentum = bn_momentum - self.bn_eps = bn_eps - self.drop_rate = drop_rate - self.act_fn = act_fn - self.num_features = 1280 - - self.conv_stem = nn.Conv2d(in_chans, stem_size, 3, padding=1, stride=2, bias=False) - self.bn0 = nn.BatchNorm2d(stem_size, momentum=self.bn_momentum, eps=self.bn_eps) - - blocks = [] - for i, a in enumerate(block_args): - print(a) #FIXME debug - a['in_chs'] = _round_channels(a['in_chs'], depth_multiplier, depth_divisor, min_depth) - a['out_chs'] = _round_channels(a['out_chs'], depth_multiplier, depth_divisor, min_depth) - blocks.append(self._make_stack(**a)) - out_chs = a['out_chs'] - self.blocks = nn.Sequential(*blocks) - - self.conv_head = nn.Conv2d(out_chs, self.num_features, 1, padding=0, stride=1, bias=False) - self.bn1 = nn.BatchNorm2d(self.num_features, momentum=self.bn_momentum, eps=self.bn_eps) - - self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - self.classifier = nn.Linear(self.num_features, self.num_classes) - - for m in self.modules(): - _initialize_weight(m) - - def _make_stack(self, in_chs, out_chs, kernel_size, stride, - exp_ratio, id_skip, se_ratio, num_repeat): - blocks = [MnasBlock( - in_chs, out_chs, kernel_size, stride, exp_ratio, id_skip, se_ratio, - bn_momentum=self.bn_momentum, bn_eps=self.bn_eps)] - for _ in range(1, num_repeat): - blocks += [MnasBlock( - out_chs, out_chs, kernel_size, 1, exp_ratio, id_skip, se_ratio, - bn_momentum=self.bn_momentum, bn_eps=self.bn_eps)] - return nn.Sequential(*blocks) - - def get_classifier(self): - return self.classifier - - def reset_classifier(self, num_classes, global_pool='avg'): - self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - self.num_classes = num_classes - del self.classifier - if num_classes: - self.classifier = nn.Linear( - self.num_features * self.global_pool.feat_mult(), num_classes) - else: - self.classifier = None - - def forward_features(self, x, pool=True): - x = self.conv_stem(x) - x = self.bn0(x) - x = self.act_fn(x) - x = self.blocks(x) - x = self.conv_head(x) - x = self.bn1(x) - x = self.act_fn(x) - if pool: - x = self.global_pool(x) - x = x.view(x.size(0), -1) - return x - - def forward(self, x): - x = self.forward_features(x) - if self.drop_rate > 0.: - x = F.dropout(x, p=self.drop_rate, training=self.training) - return self.classifier(x) - - -def mnasnet_a1(depth_multiplier, num_classes=1000, **kwargs): - """Creates a mnasnet-a1 model. - Args: - depth_multiplier: multiplier to number of channels per layer. - """ - - # defs from https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py - block_defs = [ - 'r1_k3_s11_e1_i32_o16_noskip', - 'r2_k3_s22_e6_i16_o24', - 'r3_k5_s22_e3_i24_o40_se0.25', - 'r4_k3_s22_e6_i40_o80', - 'r2_k3_s11_e6_i80_o112_se0.25', - 'r3_k5_s22_e6_i112_o160_se0.25', - 'r1_k3_s11_e6_i160_o320' - ] - block_args = _decode_block_args(block_defs) - model = MnasNet( - block_args, - num_classes=num_classes, - depth_multiplier=depth_multiplier, - depth_divisor=8, - min_depth=None, - stem_size=32, - bn_momentum=_BN_MOMENTUM_DEFAULT, - bn_eps=_BN_EPS_DEFAULT, - #drop_rate=0.2, - **kwargs - ) - return model - - -def mnasnet_b1(depth_multiplier, num_classes=1000, **kwargs): - """Creates a mnasnet-b1 model. - Args: - depth_multiplier: multiplier to number of channels per layer. - """ - # from https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py - blocks_defs = [ - 'r1_k3_s11_e1_i32_o16_noskip', - 'r3_k3_s22_e3_i16_o24', - 'r3_k5_s22_e3_i24_o40', - 'r3_k5_s22_e6_i40_o80', - 'r2_k3_s11_e6_i80_o96', - 'r4_k5_s22_e6_i96_o192', - 'r1_k3_s11_e6_i192_o320_noskip' - ] - block_args = _decode_block_args(blocks_defs) - model = MnasNet( - block_args, - num_classes=num_classes, - depth_multiplier=depth_multiplier, - depth_divisor=8, - min_depth=None, - stem_size=32, - bn_momentum=_BN_MOMENTUM_DEFAULT, - bn_eps=_BN_EPS_DEFAULT, - #drop_rate=0.2, - **kwargs - ) - return model - - -def mnasnet_small(depth_multiplier, num_classes=1000, **kwargs): - """Creates a mnasnet-b1 model. - Args: - depth_multiplier: multiplier to number of channels per layer. - """ - # from https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py - blocks_defs = [ - 'r1_k3_s11_e1_i16_o8', - 'r1_k3_s22_e3_i8_o16', - 'r2_k3_s22_e6_i16_o16', - 'r4_k5_s22_e6_i16_o32_se0.25', - 'r3_k3_s11_e6_i32_o32_se0.25', - 'r3_k5_s22_e6_i32_o88_se0.25', - 'r1_k3_s11_e6_i88_o144' - ] - block_args = _decode_block_args(blocks_defs) - model = MnasNet( - block_args, - num_classes=num_classes, - depth_multiplier=depth_multiplier, - depth_divisor=8, - min_depth=None, - stem_size=8, - bn_momentum=_BN_MOMENTUM_DEFAULT, - bn_eps=_BN_EPS_DEFAULT, - #drop_rate=0.2, - **kwargs - ) - return model - - -def mnasnet0_50(num_classes=1000, in_chans=3, pretrained=False, **kwargs): - """ MNASNet B1, depth multiplier of 0.5. """ - default_cfg = default_cfgs['mnasnet0_50'] - model = mnasnet_b1(0.5, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -def mnasnet0_75(num_classes, in_chans=3, pretrained=False, **kwargs): - """ MNASNet B1, depth multiplier of 0.75. """ - default_cfg = default_cfgs['mnasnet0_50'] - model = mnasnet_b1(0.75, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -def mnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs): - """ MNASNet B1, depth multiplier of 1.0. """ - default_cfg = default_cfgs['mnasnet0_50'] - model = mnasnet_b1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -def mnasnet1_40(num_classes, in_chans=3, pretrained=False, **kwargs): - """ MNASNet B1, depth multiplier of 1.4 """ - default_cfg = default_cfgs['mnasnet0_50'] - model = mnasnet_b1(1.4, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -def semnasnet0_50(num_classes=1000, in_chans=3, pretrained=False, **kwargs): - """ MNASNet A1 (w/ SE), depth multiplier of 0.5 """ - default_cfg = default_cfgs['mnasnet0_50'] - model = mnasnet_a1(0.5, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -def semnasnet0_75(num_classes, in_chans=3, pretrained=False, **kwargs): - """ MNASNet A1 (w/ SE), depth multiplier of 0.75. """ - default_cfg = default_cfgs['mnasnet0_50'] - model = mnasnet_a1(0.75, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -def semnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs): - """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """ - default_cfg = default_cfgs['mnasnet0_50'] - model = mnasnet_a1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -def semnasnet1_40(num_classes, in_chans=3, pretrained=False, **kwargs): - """ MNASNet A1 (w/ SE), depth multiplier of 1.4. """ - default_cfg = default_cfgs['mnasnet0_50'] - model = mnasnet_a1(1.4, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -def mnasnet_small(num_classes, in_chans=3, pretrained=False, **kwargs): - """ MNASNet Small, depth multiplier of 1.0. """ - default_cfg = default_cfgs['mnasnet_small'] - model = mnasnet_small(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model diff --git a/models/model_factory.py b/models/model_factory.py index 38e79291..99cfa931 100644 --- a/models/model_factory.py +++ b/models/model_factory.py @@ -8,12 +8,21 @@ from models.senet import seresnet18, seresnet34, seresnet50, seresnet101, seresn seresnext26_32x4d, seresnext50_32x4d, seresnext101_32x4d from models.xception import xception from models.pnasnet import pnasnet5large -from models.mnasnet import mnasnet0_50, mnasnet0_75, mnasnet1_00, mnasnet1_40,\ - semnasnet0_50, semnasnet0_75, semnasnet1_00, semnasnet1_40, mnasnet_small +from models.genmobilenet import \ + mnasnet0_50, mnasnet0_75, mnasnet1_00, mnasnet1_40,\ + semnasnet0_50, semnasnet0_75, semnasnet1_00, semnasnet1_40, mnasnet_small,\ + mobilenetv1_1_00, mobilenetv2_1_00, fbnetc_1_00, chamnetv1_1_00, chamnetv2_1_00 from models.helpers import load_checkpoint +def _is_genmobilenet(name): + genmobilenets = ['mnasnet', 'semnasnet', 'fbnet', 'chamnet', 'mobilenet'] + if any([name.startswith(x) for x in genmobilenets]): + return True + return False + + def create_model( model_name='resnet50', pretrained=None, @@ -24,6 +33,14 @@ def create_model( margs = dict(num_classes=num_classes, in_chans=in_chans, pretrained=pretrained) + # Not all models have support for batchnorm params passed as args, only genmobilenet variants + # FIXME better way to do this without pushing support into every other model fn? + supports_bn_params = _is_genmobilenet(model_name) + if not supports_bn_params and any([x in kwargs for x in ['bn_tf', 'bn_momentum', 'bn_eps']]): + kwargs.pop('bn_tf', None) + kwargs.pop('bn_momentum', None) + kwargs.pop('bn_eps', None) + if model_name in globals(): create_fn = globals()[model_name] model = create_fn(**margs, **kwargs) diff --git a/optim/__init__.py b/optim/__init__.py index 62ab43ba..2fdbcab9 100644 --- a/optim/__init__.py +++ b/optim/__init__.py @@ -1,3 +1,4 @@ from optim.adabound import AdaBound from optim.nadam import Nadam +from optim.rmsprop_tf import RMSpropTF from optim.optim_factory import create_optimizer \ No newline at end of file diff --git a/optim/optim_factory.py b/optim/optim_factory.py index d207dbcf..a77d668f 100644 --- a/optim/optim_factory.py +++ b/optim/optim_factory.py @@ -1,5 +1,5 @@ from torch import optim as optim -from optim import Nadam, AdaBound +from optim import Nadam, AdaBound, RMSpropTF def create_optimizer(args, parameters): @@ -24,6 +24,10 @@ def create_optimizer(args, parameters): optimizer = optim.RMSprop( parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps, momentum=args.momentum, weight_decay=args.weight_decay) + elif args.opt.lower() == 'rmsproptf': + optimizer = RMSpropTF( + parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps, + momentum=args.momentum, weight_decay=args.weight_decay) else: assert False and "Invalid optimizer" raise ValueError diff --git a/optim/rmsprop_tf.py b/optim/rmsprop_tf.py new file mode 100644 index 00000000..b88238ea --- /dev/null +++ b/optim/rmsprop_tf.py @@ -0,0 +1,105 @@ +import torch +from torch.optim import Optimizer + + +class RMSpropTF(Optimizer): + """Implements RMSprop algorithm (TensorFlow style epsilon) + + NOTE: This is a direct cut-and-paste of PyTorch RMSprop with eps applied before sqrt + to closer match Tensorflow for matching hyper-params. + + Proposed by G. Hinton in his + `course `_. + + The centered version first appears in `Generating Sequences + With Recurrent Neural Networks `_. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-2) + momentum (float, optional): momentum factor (default: 0) + alpha (float, optional): smoothing constant (default: 0.99) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + centered (bool, optional) : if ``True``, compute the centered RMSProp, + the gradient is normalized by an estimation of its variance + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + + """ + + def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= momentum: + raise ValueError("Invalid momentum value: {}".format(momentum)) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0.0 <= alpha: + raise ValueError("Invalid alpha value: {}".format(alpha)) + + defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay) + super(RMSpropTF, self).__init__(params, defaults) + + def __setstate__(self, state): + super(RMSpropTF, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('momentum', 0) + group.setdefault('centered', False) + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('RMSprop does not support sparse gradients') + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + state['square_avg'] = torch.zeros_like(p.data) + if group['momentum'] > 0: + state['momentum_buffer'] = torch.zeros_like(p.data) + if group['centered']: + state['grad_avg'] = torch.zeros_like(p.data) + + square_avg = state['square_avg'] + alpha = group['alpha'] + + state['step'] += 1 + + if group['weight_decay'] != 0: + grad = grad.add(group['weight_decay'], p.data) + + square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad) + + if group['centered']: + grad_avg = state['grad_avg'] + grad_avg.mul_(alpha).add_(1 - alpha, grad) + avg = square_avg.addcmul(-1, grad_avg, grad_avg).add(group['eps']).sqrt_() + else: + avg = square_avg.add(group['eps']).sqrt_() + + if group['momentum'] > 0: + buf = state['momentum_buffer'] + buf.mul_(group['momentum']).addcdiv_(grad, avg) + p.data.add_(-group['lr'], buf) + else: + p.data.addcdiv_(-group['lr'], grad, avg) + + return loss diff --git a/train.py b/train.py index 03076d88..08be0af8 100644 --- a/train.py +++ b/train.py @@ -77,10 +77,16 @@ parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR', help='warmup learning rate (default: 0.0001)') parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)') -parser.add_argument('--weight-decay', type=float, default=0.0001, metavar='M', +parser.add_argument('--weight-decay', type=float, default=0.0001, help='weight decay (default: 0.0001)') -parser.add_argument('--smoothing', type=float, default=0.1, metavar='M', +parser.add_argument('--smoothing', type=float, default=0.1, help='label smoothing (default: 0.1)') +parser.add_argument('--bn-tf', action='store_true', default=False, + help='Use Tensorflow BatchNorm defaults for models that support it (default: False)') +parser.add_argument('--bn-momentum', type=float, default=None, + help='BatchNorm momentum override (if not None)') +parser.add_argument('--bn-eps', type=float, default=None, + help='BatchNorm epsilon override (if not None)') parser.add_argument('--seed', type=int, default=42, metavar='S', help='random seed (default: 42)') parser.add_argument('--log-interval', type=int, default=50, metavar='N', @@ -154,6 +160,9 @@ def main(): num_classes=args.num_classes, drop_rate=args.drop, global_pool=args.gp, + bn_tf=args.bn_tf, + bn_momentum=args.bn_momentum, + bn_eps=args.bn_eps, checkpoint_path=args.initial_checkpoint) print('Model %s created, param count: %d' %