""" Generic EfficientNets

A generic class with building blocks to support a variety of models with efficient architectures:
* EfficientNet (B0-B7)
* EfficientNet-EdgeTPU
* EfficientNet-CondConv
* MixNet (Small, Medium, and Large)
* MnasNet B1, A1 (SE), Small
* MobileNet V1, V2, and V3
* FBNet-C (TODO A & B)
* ChamNet (TODO still guessing at architecture definition)
* Single-Path NAS Pixel1
* And likely more...

TODO not all combinations and variations have been tested. Currently working on training hyper-params...

Hacked together by Ross Wightman
"""
import math
import re
import logging
from copy import deepcopy
from functools import partial
from collections import OrderedDict, defaultdict
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.activations import Swish, sigmoid, HardSwish, hard_sigmoid
from .registry import register_model, model_entrypoint
from .helpers import load_pretrained
from .adaptive_avgmax_pool import SelectAdaptivePool2d
from .conv2d_layers import select_conv2d
from .layers import Flatten
# Public API of this module: only GenEfficientNet is exported by `from <module> import *`.
__all__ = ['GenEfficientNet']
def _cfg(url='', **kwargs):
    """Build the default config dict for one model variant.

    Args:
        url: download URL for pretrained weights ('' when none are available).
        **kwargs: extra or overriding entries (e.g. ``input_size``, ``pool_size``,
            ``crop_pct``, ``mean``, ``std``) merged on top of the base values.

    Returns:
        A new dict. Keys supplied in ``kwargs`` replace the base values.
    """
    return {
        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.875, 'interpolation': 'bicubic',
        'first_conv': 'conv_stem', 'classifier': 'classifier',
        # merged last so per-model overrides win over the base values above
        **kwargs,
    }
# Per-model default configs (input resolution, pooling size, crop pct, normalization overrides).
# NOTE(review): entries tagged "TODO restore pretrained url" had an incomplete `_cfg(` call in
# this file (the weight URL line is missing). They are closed with url='' (no pretrained
# weights) so the module imports; restore the real URLs before relying on pretrained loading.
default_cfgs = {
    'mnasnet_050': _cfg(url=''),
    'mnasnet_075': _cfg(url=''),
    'mnasnet_100': _cfg(url=''),  # TODO restore pretrained url
    'mnasnet_140': _cfg(url=''),
    'semnasnet_050': _cfg(url=''),
    'semnasnet_075': _cfg(url=''),
    'semnasnet_100': _cfg(url=''),  # TODO restore pretrained url
    'semnasnet_140': _cfg(url=''),
    'mnasnet_small': _cfg(url=''),
    'mobilenetv1_100': _cfg(url=''),
    'mobilenetv2_100': _cfg(url=''),
    'mobilenetv3_050': _cfg(url=''),
    'mobilenetv3_075': _cfg(url=''),
    'mobilenetv3_100': _cfg(url=''),  # TODO restore pretrained url
    'chamnetv1_100': _cfg(url=''),
    'chamnetv2_100': _cfg(url=''),
    'fbnetc_100': _cfg(url=''),  # TODO restore pretrained url
    'spnasnet_100': _cfg(url=''),  # TODO restore pretrained url
    'efficientnet_b0': _cfg(url=''),  # TODO restore pretrained url
    'efficientnet_b1': _cfg(
        input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
    'efficientnet_b2': _cfg(
        input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
    'efficientnet_b3': _cfg(
        url='', input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
    'efficientnet_b4': _cfg(
        url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
    'efficientnet_b5': _cfg(
        url='', input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
    'efficientnet_b6': _cfg(
        url='', input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942),
    'efficientnet_b7': _cfg(
        url='', input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
    'efficientnet_es': _cfg(url=''),  # TODO restore pretrained url
    'efficientnet_em': _cfg(
        input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
    'efficientnet_el': _cfg(
        input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
    'efficientnet_cc_b0_4e': _cfg(url=''),
    'efficientnet_cc_b0_8e': _cfg(url=''),
    'efficientnet_cc_b1_8e': _cfg(url='', input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
    'tf_efficientnet_b0': _cfg(
        input_size=(3, 224, 224)),
    'tf_efficientnet_b1': _cfg(
        input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
    'tf_efficientnet_b2': _cfg(
        input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
    'tf_efficientnet_b3': _cfg(
        input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
    'tf_efficientnet_b4': _cfg(
        input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
    'tf_efficientnet_b5': _cfg(
        input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
    'tf_efficientnet_b6': _cfg(
        input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942),
    'tf_efficientnet_b7': _cfg(
        input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
    'tf_efficientnet_es': _cfg(
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
        input_size=(3, 224, 224), ),
    'tf_efficientnet_em': _cfg(
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
        input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
    'tf_efficientnet_el': _cfg(
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
        input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
    'tf_efficientnet_cc_b0_4e': _cfg(url=''),  # TODO restore pretrained url
    'tf_efficientnet_cc_b0_8e': _cfg(url=''),  # TODO restore pretrained url
    'tf_efficientnet_cc_b1_8e': _cfg(
        input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
    'mixnet_s': _cfg(url=''),  # TODO restore pretrained url
    'mixnet_m': _cfg(url=''),  # TODO restore pretrained url
    'mixnet_l': _cfg(url=''),  # TODO restore pretrained url
    'mixnet_xl': _cfg(url=''),  # TODO restore pretrained url
    'mixnet_xxl': _cfg(),
    'tf_mixnet_s': _cfg(url=''),  # TODO restore pretrained url
    'tf_mixnet_m': _cfg(url=''),  # TODO restore pretrained url
    'tf_mixnet_l': _cfg(url=''),  # TODO restore pretrained url
}
# Debug switch: enables verbose trunk-build logging and feature-info printing.
# Off by default so model construction is quiet in library use.
_DEBUG = False

# Default args for PyTorch BN impl (these equal torch.nn.BatchNorm2d's own defaults)
_BN_MOMENTUM_PT_DEFAULT = 0.1
_BN_EPS_PT_DEFAULT = 1e-5
_BN_ARGS_PT = dict(momentum=_BN_MOMENTUM_PT_DEFAULT, eps=_BN_EPS_PT_DEFAULT)

# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)
# NOTE: momentum varies btw .99 and .9997 depending on source
# .99 in official TF TPU impl
# .9997 (/w .999 in search space) for paper
_BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
_BN_EPS_TF_DEFAULT = 1e-3
_BN_ARGS_TF = dict(momentum=_BN_MOMENTUM_TF_DEFAULT, eps=_BN_EPS_TF_DEFAULT)
def _resolve_bn_args(kwargs):
    """Pop BatchNorm options out of ``kwargs`` and return the BN kwargs to use.

    Removes ``bn_tf``, ``bn_momentum`` and ``bn_eps`` from ``kwargs`` (in that order).
    A truthy ``bn_tf`` starts from a copy of ``_BN_ARGS_TF``, otherwise from a copy of
    ``_BN_ARGS_PT``. A non-None ``bn_momentum`` / ``bn_eps`` then overrides the
    ``momentum`` / ``eps`` entry. The module-level defaults themselves are never mutated.
    """
    use_tf_defaults = kwargs.pop('bn_tf', False)
    resolved = (_BN_ARGS_TF if use_tf_defaults else _BN_ARGS_PT).copy()
    for bn_key, kwarg_name in (('momentum', 'bn_momentum'), ('eps', 'bn_eps')):
        override = kwargs.pop(kwarg_name, None)
        if override is not None:
            resolved[bn_key] = override
    return resolved
def _round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
    """Round number of filters based on depth multiplier.

    Scales ``channels`` by ``multiplier``, rounds to the nearest multiple of
    ``divisor`` and clamps from below at ``channel_min``. If that rounding lands
    more than 10% under the scaled value, one extra ``divisor`` is added.

    Args:
        channels: base channel count.
        multiplier: width multiplier. A falsy value (None/0) returns ``channels`` unchanged.
        divisor: the rounded result is a multiple of this value.
        channel_min: lower bound for the result (defaults to ``divisor``).

    Returns:
        The rounded channel count (or ``channels`` as given when ``multiplier`` is falsy).
    """
    if not multiplier:
        return channels
    channels *= multiplier
    channel_min = channel_min or divisor
    new_channels = max(
        int(channels + divisor / 2) // divisor * divisor,
        channel_min)
    # Make sure that round down does not go down by more than 10%.
    if new_channels < 0.9 * channels:
        new_channels += divisor
    return new_channels
def _parse_ksize(ss):
    """Parse a kernel-size token.

    A purely numeric token (e.g. '3') yields an int. Otherwise the token is split
    on '.' and each piece is converted, e.g. '3.5.7' -> [3, 5, 7] (mixed kernels).
    """
    if not ss.isdigit():
        return list(map(int, ss.split('.')))
    return int(ss)
def _decode_block_str(block_str):
    """ Decode block definition string

    Gets a list of block arg (dicts) through a string notation of arguments.
    E.g. ir_r2_k3_s2_e1_c16_se0.25_noskip

    All args can exist in any order with the exception of the leading string which
    is assumed to indicate the block type.

    leading string - block type (
      ir = InvertedResidual, ds = DepthwiseSep, dsa = DepthwiseSep with pw act,
      er = EdgeResidual, cn = ConvBnAct)
    r - number of repeat blocks,
    k - kernel size ('.'-separated list for mixed kernels, except 'cn'),
    s - strides (1-9),
    e - expansion ratio,
    c - output channels,
    se - squeeze/excitation ratio
    n - activation fn ('re', 'r6', 'hs', or 'sw')
    a - expansion (point-wise) conv kernel size, default 1
    p - point-wise linear conv kernel size, default 1
    fc - 'fake' input channel count used to size the EdgeResidual expansion
    cc - number of CondConv experts (InvertedResidual only)
    noskip - disable the residual connection

    Args:
        block_str: a string representation of block arguments.
    Returns:
        A tuple of (block args dict, number of repeats).
    Raises:
        ValueError: if an activation code is not recognized.
        KeyError: if an option required by the block type (e.g. 'r', 'k', 'c', 's') is missing.
        AssertionError: if the block type is not recognized.
    """
    assert isinstance(block_str, str)
    ops = block_str.split('_')
    block_type = ops[0]  # take the block type off the front
    ops = ops[1:]
    options = {}
    noskip = False
    for op in ops:
        # string options being checked on individual basis, combine if they grow
        if op == 'noskip':
            noskip = True
        elif op.startswith('n'):
            # activation fn
            key = op[0]
            v = op[1:]
            if v == 're':
                value = nn.ReLU
            elif v == 'r6':
                value = nn.ReLU6
            elif v == 'hs':
                value = HardSwish
            elif v == 'sw':
                value = Swish
            else:
                # fail loudly instead of silently reusing a value from a previous option
                raise ValueError('Unknown activation fn (%s) in block string (%s)' % (v, block_str))
            options[key] = value
        else:
            # all numeric options: split 'se0.25' -> ('se', '0.25')
            splits = re.split(r'(\d.*)', op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

    # if act_layer is None, the model default (passed to model init) will be used
    act_layer = options['n'] if 'n' in options else None
    exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
    pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
    fake_in_chs = int(options['fc']) if 'fc' in options else 0  # FIXME hack to deal with in_chs issue in TPU def

    num_repeat = int(options['r'])
    # each type of block has different valid arguments, fill accordingly
    if block_type == 'ir':
        block_args = dict(
            block_type=block_type,
            dw_kernel_size=_parse_ksize(options['k']),
            exp_kernel_size=exp_kernel_size,
            pw_kernel_size=pw_kernel_size,
            out_chs=int(options['c']),
            exp_ratio=float(options['e']),
            se_ratio=float(options['se']) if 'se' in options else None,
            stride=int(options['s']),
            act_layer=act_layer,
            noskip=noskip,
            num_experts=int(options['cc']) if 'cc' in options else 0,
        )
    elif block_type == 'ds' or block_type == 'dsa':
        block_args = dict(
            block_type=block_type,
            dw_kernel_size=_parse_ksize(options['k']),
            pw_kernel_size=pw_kernel_size,
            out_chs=int(options['c']),
            se_ratio=float(options['se']) if 'se' in options else None,
            stride=int(options['s']),
            act_layer=act_layer,
            pw_act=block_type == 'dsa',
            noskip=block_type == 'dsa' or noskip,
        )
    elif block_type == 'er':
        block_args = dict(
            block_type=block_type,
            exp_kernel_size=_parse_ksize(options['k']),
            pw_kernel_size=pw_kernel_size,
            out_chs=int(options['c']),
            exp_ratio=float(options['e']),
            fake_in_chs=fake_in_chs,
            se_ratio=float(options['se']) if 'se' in options else None,
            stride=int(options['s']),
            act_layer=act_layer,
            noskip=noskip,
        )
    elif block_type == 'cn':
        block_args = dict(
            block_type=block_type,
            kernel_size=int(options['k']),
            out_chs=int(options['c']),
            stride=int(options['s']),
            act_layer=act_layer,
        )
    else:
        assert False, 'Unknown block type (%s)' % block_type

    return block_args, num_repeat
def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
    """ Per-stage depth scaling

    Scales the block repeats in each stage. This depth scaling impl maintains
    compatibility with the EfficientNet scaling method, while allowing sensible
    scaling for other models that may have multiple block arg definitions in each stage.

    Args:
        stack_args: block arg dicts for the stage (one per block definition).
        repeats: repeat count for each entry of ``stack_args``.
        depth_multiplier: multiplier applied to the stage's total repeat count.
        depth_trunc: 'round' rounds the scaled total (min 1); anything else uses ceil.

    Returns:
        A flat list with each block arg dict deep-copied once per scaled repeat.
    """

    # We scale the total repeat count for each stage, there may be multiple
    # block arg defs per stage so we need to sum.
    num_repeat = sum(repeats)
    if depth_trunc == 'round':
        # Truncating to int by rounding allows stages with few repeats to remain
        # proportionally smaller for longer. This is a good choice when stage definitions
        # include single repeat stages that we'd prefer to keep that way as long as possible
        num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
    else:
        # The default for EfficientNet truncates repeats to int via 'ceil'.
        # Any multiplier > 1.0 will result in an increased depth for every stage.
        num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))

    # Proportionally distribute repeat count scaling to each block definition in the stage.
    # Allocation is done in reverse as it results in the first block being less likely to be scaled.
    # The first block makes less sense to repeat in most of the arch definitions.
    repeats_scaled = []
    for r in repeats[::-1]:
        rs = max(1, round((r / num_repeat * num_repeat_scaled)))
        repeats_scaled.append(rs)
        num_repeat -= r
        num_repeat_scaled -= rs
    repeats_scaled = repeats_scaled[::-1]

    # Apply the calculated scaling to each block arg in the stage
    sa_scaled = []
    for ba, rep in zip(stack_args, repeats_scaled):
        sa_scaled.extend([deepcopy(ba) for _ in range(rep)])
    return sa_scaled
def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1):
    """Decode an architecture definition into per-stage lists of block args.

    Args:
        arch_def: list of stages. Each stage is a list of block definition strings
            (see ``_decode_block_str``).
        depth_multiplier: repeat-count multiplier passed to ``_scale_stage_depth``.
        depth_trunc: 'ceil' or 'round', passed to ``_scale_stage_depth``.
        experts_multiplier: multiplies ``num_experts`` for blocks that already use
            CondConv experts (``num_experts > 0``) when greater than 1.

    Returns:
        A list with one entry per stage, each a flat list of block arg dicts with
        repeats expanded.
    """
    arch_args = []
    for block_strings in arch_def:
        assert isinstance(block_strings, list)
        stack_args = []
        repeats = []
        for block_str in block_strings:
            assert isinstance(block_str, str)
            ba, rep = _decode_block_str(block_str)
            if ba.get('num_experts', 0) > 0 and experts_multiplier > 1:
                ba['num_experts'] *= experts_multiplier
            # collect args and repeat counts so the stage can be depth-scaled as a whole
            stack_args.append(ba)
            repeats.append(rep)
        arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc))
    return arch_args
def swish_jit_fwd(x):
    """Swish forward pass: ``x * sigmoid(x)`` as a new tensor.

    NOTE(review): the name suggests this was meant to be ``@torch.jit.script``-decorated.
    """
    gate = torch.sigmoid(x)
    return x.mul(gate)
def swish_jit_bwd(x, grad_output):
    """Swish backward pass: ``grad_output`` times d/dx[x * sigmoid(x)].

    With s = sigmoid(x) the local derivative is s * (1 + x * (1 - s)).
    """
    s = torch.sigmoid(x)
    local_grad = s * (1 + x * (1 - s))
    return grad_output * local_grad
class SwishJitAutoFn(torch.autograd.Function):
    """ torch.jit.script optimised Swish

    Inspired by conversation btw Jeremy Howard & Adam Pazske

    Custom autograd function that saves only the input tensor for the backward
    pass and recomputes the sigmoid there (via ``swish_jit_fwd`` / ``swish_jit_bwd``).
    """

    @staticmethod
    def forward(ctx, x):
        # backward() reads ctx.saved_tensors[0], so the input must be saved here
        ctx.save_for_backward(x)
        return swish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return swish_jit_bwd(x, grad_output)
# When True, `swish` routes through the custom autograd function `SwishJitAutoFn`,
# which saves only the input for backward. False (default) uses the plain eager
# expression below.
_USE_SWISH_OPT = False


def swish(x, inplace=False):
    """Swish activation ``x * sigmoid(x)``.

    Args:
        x: input tensor.
        inplace: on the default path, overwrite and return ``x``. Ignored when
            ``_USE_SWISH_OPT`` is enabled (the autograd fn always returns a new tensor).
    """
    if _USE_SWISH_OPT:
        return SwishJitAutoFn.apply(x)
    return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
def sigmoid(x, inplace=False):
    """Logistic sigmoid. With ``inplace`` the input tensor is overwritten and returned."""
    if inplace:
        return x.sigmoid_()
    return x.sigmoid()
def hard_swish(x, inplace=False):
    """Hard swish ``x * relu6(x + 3) / 6``. With ``inplace`` the input is multiplied in place."""
    if not inplace:
        return x * F.relu6(x + 3.) / 6.
    return x.mul_(F.relu6(x + 3.) / 6.)
def hard_sigmoid(x, inplace=False):
    """Piecewise-linear sigmoid ``relu6(x + 3) / 6``. With ``inplace`` the input is overwritten."""
    if not inplace:
        return F.relu6(x + 3.) / 6.
    return x.add_(3.).clamp_(0., 6.).div_(6.)
class _BlockBuilder:
    """ Build Trunk Blocks

    Turns decoded block args (a list of stages, each a list of block arg dicts) into
    block modules, one ``nn.Sequential`` per stage. While building it:
      * rounds output channels with the model's channel multiplier/divisor,
      * fills in model-level defaults (norm layer/kwargs, padding, act layer, SE gate fn),
      * ramps drop-connect rate linearly with the global block index,
      * converts strides into dilation once ``output_stride`` would be exceeded,
      * records feature-extraction points in ``self.features`` (see ``feature_location``).
    """

    def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
                 output_stride=32, pad_type='', act_layer=None, se_gate_fn=sigmoid, se_reduce_mid=False,
                 norm_layer=nn.BatchNorm2d, norm_kwargs=_BN_ARGS_PT, drop_connect_rate=0., feature_location='',
                 verbose=False):
        self.channel_multiplier = channel_multiplier
        self.channel_divisor = channel_divisor
        self.channel_min = channel_min
        self.output_stride = output_stride
        self.pad_type = pad_type
        self.act_layer = act_layer
        self.se_gate_fn = se_gate_fn
        self.se_reduce_mid = se_reduce_mid
        self.norm_layer = norm_layer
        self.norm_kwargs = norm_kwargs
        self.drop_connect_rate = drop_connect_rate
        self.feature_location = feature_location
        assert feature_location in ('pre_pwl', 'post_exp', '')
        self.verbose = verbose

        # state updated during build, consumed by model
        self.in_chs = None
        # feature_idx -> dict(name=<module path>, num_chs=<channels>)
        self.features = OrderedDict()

    def _round_channels(self, chs):
        return _round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)

    def _make_block(self, ba, block_idx, block_count):
        """Instantiate one block from its decoded args. ``ba`` is modified in place."""
        # drop-connect rate grows linearly with the global block index
        drop_connect_rate = self.drop_connect_rate * block_idx / block_count
        bt = ba.pop('block_type')
        ba['in_chs'] = self.in_chs
        ba['out_chs'] = self._round_channels(ba['out_chs'])
        if 'fake_in_chs' in ba and ba['fake_in_chs']:
            # FIXME this is a hack to work around mismatch in origin impl input filters
            ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
        ba['norm_layer'] = self.norm_layer
        ba['norm_kwargs'] = self.norm_kwargs
        ba['pad_type'] = self.pad_type
        # block act fn overrides the model default
        ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
        assert ba['act_layer'] is not None
        if bt == 'ir':
            ba['drop_connect_rate'] = drop_connect_rate
            ba['se_gate_fn'] = self.se_gate_fn
            ba['se_reduce_mid'] = self.se_reduce_mid
            if self.verbose:
                logging.info('  InvertedResidual {}, Args: {}'.format(block_idx, str(ba)))
            block = InvertedResidual(**ba)
        elif bt == 'ds' or bt == 'dsa':
            ba['drop_connect_rate'] = drop_connect_rate
            # DS blocks support SE too, so they get the model-level gate fn like ir/er
            ba['se_gate_fn'] = self.se_gate_fn
            if self.verbose:
                logging.info('  DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)))
            block = DepthwiseSeparableConv(**ba)
        elif bt == 'er':
            ba['drop_connect_rate'] = drop_connect_rate
            ba['se_gate_fn'] = self.se_gate_fn
            ba['se_reduce_mid'] = self.se_reduce_mid
            if self.verbose:
                logging.info('  EdgeResidual {}, Args: {}'.format(block_idx, str(ba)))
            block = EdgeResidual(**ba)
        elif bt == 'cn':
            if self.verbose:
                logging.info('  ConvBnAct {}, Args: {}'.format(block_idx, str(ba)))
            block = ConvBnAct(**ba)
        else:
            assert False, 'Uknkown block type (%s) while building model.' % bt
        self.in_chs = ba['out_chs']  # update in_chs for arg of next block

        return block

    def __call__(self, in_chs, model_block_args):
        """ Build the blocks

        Args:
            in_chs: Number of input-channels passed to first block
            model_block_args: A list of lists, outer list defines stages, inner
                list contains decoded block arg dicts for that stage
        Returns:
             List of block stacks (each stack wrapped in nn.Sequential)
        """
        if self.verbose:
            logging.info('Building model trunk with %d stages...' % len(model_block_args))
        self.in_chs = in_chs
        total_block_count = sum([len(x) for x in model_block_args])
        total_block_idx = 0
        current_stride = 2  # the stem conv already downsamples by 2
        current_dilation = 1
        feature_idx = 0
        stages = []
        # outer list of block_args defines the stacks ('stages' by some conventions)
        for stage_idx, stage_block_args in enumerate(model_block_args):
            last_stack = stage_idx == (len(model_block_args) - 1)
            if self.verbose:
                logging.info('Stack: {}'.format(stage_idx))
            assert isinstance(stage_block_args, list)

            blocks = []
            # each stack (stage) contains a list of block arguments
            for block_idx, block_args in enumerate(stage_block_args):
                last_block = block_idx == (len(stage_block_args) - 1)
                extract_features = ''  # No features extracted
                if self.verbose:
                    logging.info(' Block: {}'.format(block_idx))

                # Sort out stride, dilation, and feature extraction details
                assert block_args['stride'] in (1, 2)
                if block_idx >= 1:
                    # only the first block in any stack can have a stride > 1
                    block_args['stride'] = 1

                do_extract = False
                if self.feature_location == 'pre_pwl':
                    # extract at the last block before the next downsample (or at the very end)
                    if last_block:
                        next_stage_idx = stage_idx + 1
                        if next_stage_idx >= len(model_block_args):
                            do_extract = True
                        else:
                            do_extract = model_block_args[next_stage_idx][0]['stride'] > 1
                elif self.feature_location == 'post_exp':
                    if block_args['stride'] > 1 or (last_stack and last_block):
                        do_extract = True
                if do_extract:
                    extract_features = self.feature_location

                next_dilation = current_dilation
                if block_args['stride'] > 1:
                    next_output_stride = current_stride * block_args['stride']
                    if next_output_stride > self.output_stride:
                        next_dilation = current_dilation * block_args['stride']
                        block_args['stride'] = 1
                        if self.verbose:
                            logging.info('  Converting stride to dilation to maintain output_stride=={}'.format(
                                self.output_stride))
                    else:
                        current_stride = next_output_stride
                # this block uses the current dilation; the increased dilation applies from the next block on
                block_args['dilation'] = current_dilation
                if next_dilation != current_dilation:
                    current_dilation = next_dilation

                # create the block
                block = self._make_block(block_args, total_block_idx, total_block_count)
                blocks.append(block)

                # stash feature module name and channel info for model feature extraction
                if extract_features:
                    feature_module = block.feature_module(extract_features)
                    if feature_module:
                        feature_module = 'blocks.{}.{}.'.format(stage_idx, block_idx) + feature_module
                    feature_channels = block.feature_channels(extract_features)
                    self.features[feature_idx] = dict(
                        name=feature_module,
                        num_chs=feature_channels,
                    )
                    feature_idx += 1

                total_block_idx += 1  # incr global block idx (across all stacks)
            stages.append(nn.Sequential(*blocks))
        return stages
def _init_weight_goog(m):
    """Initialize one module in place, following the TensorFlow official MnasNet impl.

    - Conv2d: normal(0, sqrt(2 / fan_out)) weights, zero bias.
    - BatchNorm2d: weight 1, bias 0.
    - Linear: uniform(-1/sqrt(fan_out), 1/sqrt(fan_out)) weights, zero bias.
    Other module types are left untouched.
    """
    # weight init as per Tensorflow Official impl
    # https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
    if isinstance(m, nn.Conv2d):
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels  # fan-out
        m.weight.data.normal_(0, math.sqrt(2.0 / n))
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1.0)
        m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        n = m.weight.size(0)  # fan-out
        init_range = 1.0 / math.sqrt(n)
        m.weight.data.uniform_(-init_range, init_range)
        if m.bias is not None:
            m.bias.data.zero_()
def _init_weight_default(m):
    """Initialize one module in place with PyTorch-style Kaiming init.

    - Conv2d: kaiming normal (fan_out, relu) weights. The bias is left as is.
    - BatchNorm2d: weight 1, bias 0.
    - Linear: kaiming uniform (fan_in, linear) weights. The bias is left as is.
    Other module types are left untouched.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1.0)
        m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear')
def drop_connect(inputs, training=False, drop_connect_rate=0.):
    """Apply drop connect (per-sample stochastic depth on a residual branch).

    In eval mode ``inputs`` is returned as-is. In training mode each sample in the
    batch is zeroed with probability ``drop_connect_rate``, and the survivors are
    scaled by ``1 / (1 - drop_connect_rate)``.
    """
    if training:
        keep_prob = 1 - drop_connect_rate
        mask_shape = (inputs.size()[0], 1, 1, 1)
        noise = torch.rand(mask_shape, dtype=inputs.dtype, device=inputs.device)
        # keep_prob + U[0, 1) floors to 1 with probability keep_prob, else 0
        mask = (keep_prob + noise).floor_()
        return inputs.div(keep_prob) * mask
    return inputs
class ChannelShuffle(nn.Module):
    """Interleave channels across ``groups`` (channel shuffle).

    FIXME haven't used yet
    """

    def __init__(self, groups):
        super(ChannelShuffle, self).__init__()
        self.groups = groups

    def forward(self, x):
        """Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]"""
        N, C, H, W = x.size()
        g = self.groups
        assert C % g == 0, "Incompatible group size {} for input channel {}".format(
            g, C
        )
        return (
            x.view(N, g, C // g, H, W)
            .permute(0, 2, 1, 3, 4)
            # permute() yields a non-contiguous tensor; .view() on it raises,
            # so materialize a contiguous copy first
            .contiguous()
            .view(N, C, H, W)
        )
class SqueezeExcite(nn.Module):
    """Squeeze-and-excitation channel gating.

    Global-average-pools the input, passes it through a 1x1 conv bottleneck
    (``in_chs -> reduce_chs -> in_chs``, default ``reduce_chs = in_chs``) with
    ``act_layer`` in between, and multiplies the input by ``gate_fn`` of the result.
    """

    def __init__(self, in_chs, reduce_chs=None, act_layer=nn.ReLU, gate_fn=sigmoid):
        super(SqueezeExcite, self).__init__()
        self.gate_fn = gate_fn
        hidden_chs = in_chs if not reduce_chs else reduce_chs
        # registration order (conv_reduce, act1, conv_expand) kept for state_dict / init order
        self.conv_reduce = nn.Conv2d(in_chs, hidden_chs, 1, bias=True)
        self.act1 = act_layer(inplace=True)
        self.conv_expand = nn.Conv2d(hidden_chs, in_chs, 1, bias=True)

    def forward(self, x):
        batch, chans = x.size(0), x.size(1)
        # NOTE adaptiveavgpool can be used here, but seems to cause issues with NVIDIA AMP performance
        pooled = x.view(batch, chans, -1).mean(-1).view(batch, chans, 1, 1)
        attn = self.conv_expand(self.act1(self.conv_reduce(pooled)))
        return x * self.gate_fn(attn)
class ConvBnAct(nn.Module):
    """Conv -> norm -> activation block (the 'cn' block type).

    Its single feature-extraction point is the activation output ('act1'),
    whatever ``location`` is requested.
    """

    def __init__(self, in_chs, out_chs, kernel_size,
                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU,
                 norm_layer=nn.BatchNorm2d, norm_kwargs=_BN_ARGS_PT,):
        super(ConvBnAct, self).__init__()
        assert stride in [1, 2]
        # submodules registered in order: conv, bn1, act1
        self.conv = select_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type)
        self.bn1 = norm_layer(out_chs, **norm_kwargs)
        self.act1 = act_layer(inplace=True)

    def feature_module(self, location):
        return 'act1'

    def feature_channels(self, location):
        return self.conv.out_channels

    def forward(self, x):
        return self.act1(self.bn1(self.conv(x)))
class EdgeResidual(nn.Module):
    """ Residual block with expansion convolution followed by pointwise-linear w/ stride

    Layout: conv_exp (k x k, in_chs -> mid_chs) -> bn1 -> act1 -> optional SE ->
    conv_pwl (carries the stride/dilation, mid_chs -> out_chs) -> bn2 -> optional
    residual add (with drop connect when ``drop_connect_rate > 0``).
    """

    def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0,
                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1,
                 se_ratio=0., se_reduce_mid=False, se_gate_fn=sigmoid,
                 norm_layer=nn.BatchNorm2d, norm_kwargs=_BN_ARGS_PT, drop_connect_rate=0.):
        super(EdgeResidual, self).__init__()
        # expansion width comes from fake_in_chs when it is set (> 0), otherwise from in_chs
        mid_chs = int(fake_in_chs * exp_ratio) if fake_in_chs > 0 else int(in_chs * exp_ratio)
        self.has_se = se_ratio is not None and se_ratio > 0.
        # skip connection only when shapes match (same channels, stride 1) and not disabled
        self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
        self.drop_connect_rate = drop_connect_rate

        # Expansion convolution (no stride here; downsampling happens in conv_pwl)
        self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
        self.bn1 = norm_layer(mid_chs, **norm_kwargs)
        self.act1 = act_layer(inplace=True)

        # Squeeze-and-excitation
        if self.has_se:
            # SE bottleneck width is a ratio of mid_chs or in_chs depending on se_reduce_mid
            se_base_chs = mid_chs if se_reduce_mid else in_chs
            self.se = SqueezeExcite(
                mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)), act_layer=act_layer, gate_fn=se_gate_fn)

        # Point-wise linear projection
        self.conv_pwl = select_conv2d(
            mid_chs, out_chs, pw_kernel_size, stride=stride, dilation=dilation, padding=pad_type)
        self.bn2 = norm_layer(out_chs, **norm_kwargs)

    def feature_module(self, location):
        # 'post_exp' -> output of act1; otherwise the module whose input is the feature (conv_pwl)
        if location == 'post_exp':
            return 'act1'
        return 'conv_pwl'

    def feature_channels(self, location):
        if location == 'post_exp':
            return self.conv_exp.out_channels
        # location == 'pre_pwl'
        return self.conv_pwl.in_channels

    def forward(self, x):
        residual = x

        # Expansion convolution
        x = self.conv_exp(x)
        x = self.bn1(x)
        x = self.act1(x)

        # Squeeze-and-excitation
        if self.has_se:
            x = self.se(x)

        # Point-wise linear projection
        x = self.conv_pwl(x)
        x = self.bn2(x)

        if self.has_residual:
            if self.drop_connect_rate > 0.:
                x = drop_connect(x, self.training, self.drop_connect_rate)
            x += residual

        return x
class DepthwiseSeparableConv(nn.Module):
    """ DepthwiseSeparable block
    Used for DS convs in MobileNet-V1 and in the place of IR blocks with an expansion
    factor of 1.0. This is an alternative to having a IR with an optional first pw conv.

    Layout: depthwise conv (carries stride/dilation) -> bn1 -> act1 -> optional SE ->
    point-wise conv -> bn2 -> act2 (identity unless ``pw_act``) -> optional residual add.
    """

    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
                 pw_kernel_size=1, pw_act=False, se_ratio=0., se_gate_fn=sigmoid,
                 norm_layer=nn.BatchNorm2d, norm_kwargs=_BN_ARGS_PT, drop_connect_rate=0.):
        super(DepthwiseSeparableConv, self).__init__()
        assert stride in [1, 2]
        self.has_se = se_ratio is not None and se_ratio > 0.
        # skip connection only when shapes match (stride 1, same channels) and not disabled
        self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
        self.has_pw_act = pw_act  # activation after point-wise conv
        self.drop_connect_rate = drop_connect_rate

        self.conv_dw = select_conv2d(
            in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, depthwise=True)
        self.bn1 = norm_layer(in_chs, **norm_kwargs)
        self.act1 = act_layer(inplace=True)

        # Squeeze-and-excitation
        if self.has_se:
            self.se = SqueezeExcite(
                in_chs, reduce_chs=max(1, int(in_chs * se_ratio)), act_layer=act_layer, gate_fn=se_gate_fn)

        self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
        self.bn2 = norm_layer(out_chs, **norm_kwargs)
        self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity()

    def feature_module(self, location):
        # no expansion in this block, pre pw only feature extraction point
        return 'conv_pw'

    def feature_channels(self, location):
        return self.conv_pw.in_channels

    def forward(self, x):
        residual = x

        x = self.conv_dw(x)
        x = self.bn1(x)
        x = self.act1(x)

        if self.has_se:
            x = self.se(x)

        x = self.conv_pw(x)
        x = self.bn2(x)
        x = self.act2(x)

        if self.has_residual:
            if self.drop_connect_rate > 0.:
                x = drop_connect(x, self.training, self.drop_connect_rate)
            x += residual
        return x
class InvertedResidual(nn.Module):
    """ Inverted residual block w/ optional SE and CondConv routing

    Layout: point-wise expansion (in_chs -> in_chs * exp_ratio) -> bn1 -> act1 ->
    depthwise conv (carries stride/dilation) -> bn2 -> act2 -> optional SE ->
    point-wise linear projection -> bn3 -> optional residual add.
    When ``num_experts > 0`` the three convs receive per-sample routing weights
    computed from the pooled block input (CondConv).
    """

    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
                 exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
                 se_ratio=0., se_reduce_mid=False, se_gate_fn=sigmoid,
                 norm_layer=nn.BatchNorm2d, norm_kwargs=_BN_ARGS_PT,
                 num_experts=0, drop_connect_rate=0.):
        super(InvertedResidual, self).__init__()
        mid_chs = int(in_chs * exp_ratio)
        self.has_se = se_ratio is not None and se_ratio > 0.
        # skip connection only when shapes match (same channels, stride 1) and not disabled
        self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
        self.drop_connect_rate = drop_connect_rate

        self.num_experts = num_experts
        extra_args = dict()
        if num_experts > 0:
            # conv layers built via select_conv2d get num_experts; routing is a linear
            # layer over the pooled input followed by a sigmoid
            extra_args = dict(num_experts=self.num_experts)
            self.routing_fn = nn.Linear(in_chs, self.num_experts)
            self.routing_act = torch.sigmoid

        # Point-wise expansion
        self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **extra_args)
        self.bn1 = norm_layer(mid_chs, **norm_kwargs)
        self.act1 = act_layer(inplace=True)

        # Depth-wise convolution
        self.conv_dw = select_conv2d(
            mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation,
            padding=pad_type, depthwise=True, **extra_args)
        self.bn2 = norm_layer(mid_chs, **norm_kwargs)
        self.act2 = act_layer(inplace=True)

        # Squeeze-and-excitation
        if self.has_se:
            # SE bottleneck width is a ratio of mid_chs or in_chs depending on se_reduce_mid
            se_base_chs = mid_chs if se_reduce_mid else in_chs
            self.se = SqueezeExcite(
                mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)), act_layer=act_layer, gate_fn=se_gate_fn)

        # Point-wise linear projection
        self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **extra_args)
        self.bn3 = norm_layer(out_chs, **norm_kwargs)

    def feature_module(self, location):
        # 'post_exp' -> output of act1; otherwise the module whose input is the feature (conv_pwl)
        if location == 'post_exp':
            return 'act1'
        return 'conv_pwl'

    def feature_channels(self, location):
        if location == 'post_exp':
            return self.conv_pw.out_channels
        # location == 'pre_pwl'
        return self.conv_pwl.in_channels

    def forward(self, x):
        residual = x

        conv_pw, conv_dw, conv_pwl = self.conv_pw, self.conv_dw, self.conv_pwl
        if self.num_experts > 0:
            # CondConv: routing weights from the globally average-pooled block input,
            # bound into each conv call via partial
            pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1)
            routing_weights = self.routing_act(self.routing_fn(pooled_inputs))
            conv_pw = partial(self.conv_pw, routing_weights=routing_weights)
            conv_dw = partial(self.conv_dw, routing_weights=routing_weights)
            conv_pwl = partial(self.conv_pwl, routing_weights=routing_weights)

        # Point-wise expansion
        x = conv_pw(x)
        x = self.bn1(x)
        x = self.act1(x)

        # Depth-wise convolution
        x = conv_dw(x)
        x = self.bn2(x)
        x = self.act2(x)

        # Squeeze-and-excitation
        if self.has_se:
            x = self.se(x)

        # Point-wise linear projection
        x = conv_pwl(x)
        x = self.bn3(x)

        if self.has_residual:
            if self.drop_connect_rate > 0.:
                x = drop_connect(x, self.training, self.drop_connect_rate)
            x += residual

        return x
class _GenEfficientNet(nn.Module):
    """ Generic EfficientNet Base

    Stem (3x3 stride-2 conv -> norm -> act) followed by the block trunk built by
    ``_BlockBuilder`` from decoded ``block_args``. Feature extraction points
    recorded by the builder are exposed as ``self.feature_info``.
    """

    def __init__(self, block_args, in_chans=3, stem_size=32,
                 channel_multiplier=1.0, channel_divisor=8, channel_min=None,
                 output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0.,
                 se_gate_fn=sigmoid, se_reduce_mid=False, norm_layer=nn.BatchNorm2d, norm_kwargs=_BN_ARGS_PT,
                 feature_location=''):
        super(_GenEfficientNet, self).__init__()
        self.drop_rate = drop_rate
        self._in_chs = in_chans

        # Stem
        stem_size = _round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
        self.conv_stem = select_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
        self.bn1 = norm_layer(stem_size, **norm_kwargs)
        self.act1 = act_layer(inplace=True)
        self._in_chs = stem_size

        # Middle stages (IR/ER/DS Blocks)
        builder = _BlockBuilder(
            channel_multiplier, channel_divisor, channel_min,
            output_stride, pad_type, act_layer, se_gate_fn, se_reduce_mid,
            norm_layer, norm_kwargs, drop_connect_rate, feature_location=feature_location, verbose=_DEBUG)
        self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
        self.feature_info = builder.features
        # channel count leaving the trunk, consumed by heads in subclasses
        self._in_chs = builder.in_chs

    def as_sequential(self):
        """Return stem + trunk blocks as one flat ``nn.Sequential`` (same modules as ``forward``)."""
        layers = [self.conv_stem, self.bn1, self.act1]
        layers.extend(self.blocks)
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv_stem(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.blocks(x)
        return x
class GenEfficientNet(_GenEfficientNet):
    """ Generic EfficientNet

    An implementation of efficient network architectures, in many cases mobile optimized networks:
      * MobileNet-V1
      * MobileNet-V2
      * MobileNet-V3
      * MnasNet A1, B1, and small
      * FBNet A, B, and C
      * ChamNet (arch details are murky)
      * Single-Path NAS Pixel1
      * EfficientNet B0-B7
      * MixNet S, M, L

    Adds a head and a linear classifier on top of the ``_GenEfficientNet`` stem + blocks:
      * head_conv='default': 1x1 conv -> norm -> act, then global pooling in ``forward``.
      * head_conv='efficient': global pooling first, then 1x1 conv -> act.
    ``weight_init`` 'goog' uses the TF-style init, anything else the Kaiming-style init.
    """

    def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32,
                 channel_multiplier=1.0, channel_divisor=8, channel_min=None,
                 pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0.,
                 se_gate_fn=sigmoid, se_reduce_mid=False,
                 norm_layer=nn.BatchNorm2d, norm_kwargs=_BN_ARGS_PT,
                 global_pool='avg', head_conv='default', weight_init='goog'):
        self.num_classes = num_classes
        self.num_features = num_features
        super(GenEfficientNet, self).__init__(  # FIXME it would be nice if Python made this nicer
            block_args, in_chans=in_chans, stem_size=stem_size,
            pad_type=pad_type, act_layer=act_layer, drop_rate=drop_rate, drop_connect_rate=drop_connect_rate,
            channel_multiplier=channel_multiplier, channel_divisor=channel_divisor, channel_min=channel_min,
            se_gate_fn=se_gate_fn, se_reduce_mid=se_reduce_mid, norm_layer=norm_layer, norm_kwargs=norm_kwargs)

        # Head + Pooling
        self.conv_head = None
        self.global_pool = None
        self.act2 = None
        self.forward_head = None
        self.head_conv = head_conv
        if head_conv == 'efficient':
            self._create_head_efficient(global_pool, pad_type, act_layer)
        elif head_conv == 'default':
            self._create_head_default(global_pool, pad_type, act_layer, norm_layer, norm_kwargs)
        else:
            # without a head there is no pooling layer to size the classifier from
            raise ValueError("Unknown head_conv (%s), expected 'default' or 'efficient'" % head_conv)

        # Classifier
        self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), self.num_classes)

        for m in self.modules():
            if weight_init == 'goog':
                _init_weight_goog(m)
            else:
                _init_weight_default(m)

    def _create_head_default(self, global_pool, pad_type, act_layer, norm_layer, norm_kwargs):
        self.conv_head = select_conv2d(self._in_chs, self.num_features, 1, padding=pad_type)
        self.bn2 = norm_layer(self.num_features, **norm_kwargs)
        self.act2 = act_layer(inplace=True)
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)

    def _create_head_efficient(self, global_pool, pad_type, act_layer):
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.conv_head = select_conv2d(self._in_chs, self.num_features, 1, padding=pad_type)
        self.act2 = act_layer(inplace=True)

    def _forward_head_default(self, x):
        x = self.conv_head(x)
        x = self.bn2(x)
        x = self.act2(x)
        return x

    def _forward_head_efficient(self, x):
        x = self.global_pool(x)
        x = self.conv_head(x)
        x = self.act2(x)
        return x

    def as_sequential(self):
        """Return the whole network (stem, blocks, head, classifier) as one ``nn.Sequential``."""
        layers = [self.conv_stem, self.bn1, self.act1]
        layers.extend(self.blocks)
        if self.head_conv == 'efficient':
            layers.extend([self.global_pool, self.conv_head, self.act2])
        else:
            layers.extend([self.conv_head, self.bn2, self.act2])
            if self.global_pool is not None:
                layers.append(self.global_pool)
        layers.extend([Flatten(), nn.Dropout(self.drop_rate)])
        if self.classifier is not None:
            layers.append(self.classifier)
        return nn.Sequential(*layers)

    def get_classifier(self):
        return self.classifier

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Replace pooling and classifier. ``num_classes=0`` removes the classifier."""
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.num_classes = num_classes
        del self.classifier
        if num_classes:
            self.classifier = nn.Linear(
                self.num_features * self.global_pool.feat_mult(), num_classes)
        else:
            self.classifier = None

    def forward_features(self, x):
        x = super(GenEfficientNet, self).forward(x)
        if self.head_conv == 'efficient':
            x = self._forward_head_efficient(x)
        elif self.head_conv == 'default':
            x = self._forward_head_default(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        # pool only when spatial dims remain (the 'efficient' head has already pooled);
        # parenthesized so a missing pool layer is never called
        if self.global_pool is not None and (x.shape[-1] > 1 or x.shape[-2] > 1):
            x = self.global_pool(x)
        x = x.flatten(1)
        if self.drop_rate > 0.:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        if self.classifier is None:
            # classifier removed via reset_classifier(0): return pooled features
            return x
        return self.classifier(x)
class GenEfficientNetFeatures(_GenEfficientNet):
    """ Generic EfficientNet Feature Extractor

    Builds the stem + block trunk (no head or classifier) and captures intermediate
    tensors with module hooks on the modules recorded in ``self.feature_info`` for
    the requested ``out_indices``.
      * feature_location='pre_pwl': the captured tensor is the *input* of the
        recorded module (forward-pre hook).
      * feature_location='post_exp': the captured tensor is the *output* of the
        recorded module (forward hook).
    """

    def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl',
                 in_chans=3, stem_size=32, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
                 output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0.,
                 se_gate_fn=sigmoid, se_reduce_mid=False, norm_layer=nn.BatchNorm2d, norm_kwargs=_BN_ARGS_PT,
                 weight_init='goog'):
        # FIXME reduce num stages created if not needed (only max(out_indices) + 1 are consumed)
        # FIXME it would be nice if Python made this nicer without using kwargs and erasing IDE hints, etc
        super(GenEfficientNetFeatures, self).__init__(
            block_args, in_chans=in_chans, stem_size=stem_size,
            output_stride=output_stride, pad_type=pad_type, act_layer=act_layer,
            drop_rate=drop_rate, drop_connect_rate=drop_connect_rate, feature_location=feature_location,
            channel_multiplier=channel_multiplier, channel_divisor=channel_divisor, channel_min=channel_min,
            se_gate_fn=se_gate_fn, se_reduce_mid=se_reduce_mid, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
        self.out_indices = out_indices

        for m in self.modules():
            if weight_init == 'goog':
                _init_weight_goog(m)
            else:
                _init_weight_default(m)

        if _DEBUG:
            for k, v in self.feature_info.items():
                print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs']))

        hook_type = 'forward_pre' if feature_location == 'pre_pwl' else 'forward'
        hooks = [dict(name=self.feature_info[idx]['name'], type=hook_type) for idx in out_indices]
        self._feature_outputs = None
        # must run after the None assignment above, since it installs the output store
        self._register_hooks(hooks)

    def _collect_output_hook(self, name, *args):
        x = args[-1]  # tensor we want is last argument, output for fwd, input for fwd_pre
        if isinstance(x, tuple):
            x = x[0]  # unwrap input tuple
        self._feature_outputs[x.device][name] = x

    def _get_output(self, device):
        # NOTE(review): the [::-1] returns features deepest-first, while feature_channels()
        # lists channels in out_indices order. Confirm which order callers expect.
        output = tuple(self._feature_outputs[device].values())[::-1]
        self._feature_outputs[device] = OrderedDict()
        return output

    def _register_hooks(self, hooks):
        # setup feature hooks
        modules = {k: v for k, v in self.named_modules()}
        for h in hooks:
            hook_name = h['name']
            m = modules[hook_name]
            hook_fn = partial(self._collect_output_hook, hook_name)
            if h['type'] == 'forward_pre':
                m.register_forward_pre_hook(hook_fn)
            else:
                m.register_forward_hook(hook_fn)
        # per-device store so replicas on different devices do not clobber each other
        self._feature_outputs = defaultdict(OrderedDict)

    def feature_channels(self, idx=None):
        """Channel count of feature ``idx``, or the list for all ``out_indices`` when idx is None."""
        if isinstance(idx, int):
            return self.feature_info[idx]['num_chs']
        return [self.feature_info[i]['num_chs'] for i in self.out_indices]

    def forward(self, x):
        x = self.conv_stem(x)
        x = self.bn1(x)
        x = self.act1(x)
        # the trunk output itself is not needed; the hooks capture the requested features
        self.blocks(x)
        return self._get_output(x.device)
def _create_model(model_kwargs, default_cfg, pretrained=False):
    """Build a ``GenEfficientNet`` (or feature extractor) and optionally load weights.

    Args:
        model_kwargs: constructor kwargs. A truthy ``features_only`` entry selects
            ``GenEfficientNetFeatures`` and drops ``num_classes``, ``num_features`` and
            ``head_conv``, which that constructor does not accept. The mapping is
            copied, so the caller's dict is left unmodified.
        default_cfg: config attached as ``model.default_cfg`` and passed to
            ``load_pretrained``.
        pretrained: if True, load weights with ``load_pretrained``.

    Returns:
        The constructed model.
    """
    model_kwargs = dict(model_kwargs)
    if model_kwargs.pop('features_only', False):
        # Non-strict load: presumably the checkpoint's classifier/head weights have no
        # counterpart in the feature extractor.
        load_strict = False
        model_kwargs.pop('num_classes', 0)
        model_kwargs.pop('num_features', 0)
        model_kwargs.pop('head_conv', None)
        model_class = GenEfficientNetFeatures
    else:
        load_strict = True
        model_class = GenEfficientNet
    model = model_class(**model_kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(
            model,
            default_cfg,
            num_classes=model_kwargs.get('num_classes', 0),
            in_chans=model_kwargs.get('in_chans', 3),
            strict=load_strict)
    return model
def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
"""Creates a mnasnet-a1 model.

Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
Paper: https://arxiv.org/pdf/1807.11626.pdf.

Args:
    variant: key into ``default_cfgs``; that config is passed to ``_create_model``.
    channel_multiplier: multiplier to number of channels per layer.
    pretrained: forwarded to ``_create_model`` to request pretrained weights.
    **kwargs: extra model options (presumably merged into ``model_kwargs``; not visible here).

Returns:
    The model built by ``_create_model``.
"""
arch_def = [
# stage 0, 112x112 in
# stage 1, 112x112 in
# stage 2, 56x56 in
# stage 3, 28x28 in
# stage 4, 14x14 in
# stage 5, 14x14 in
# stage 6, 7x7 in
model_kwargs = dict(
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
return model
def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
"""Creates a mnasnet-b1 model.

Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
Paper: https://arxiv.org/pdf/1807.11626.pdf.

Args:
    variant: key into ``default_cfgs``; that config is passed to ``_create_model``.
    channel_multiplier: multiplier to number of channels per layer.
    pretrained: forwarded to ``_create_model`` to request pretrained weights.
    **kwargs: extra model options (presumably merged into ``model_kwargs``; not visible here).

Returns:
    The model built by ``_create_model``.
"""
arch_def = [
# stage 0, 112x112 in
# stage 1, 112x112 in
# stage 2, 56x56 in
# stage 3, 28x28 in
# stage 4, 14x14 in
# stage 5, 14x14 in
# stage 6, 7x7 in
model_kwargs = dict(
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
return model
def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
"""Creates a MNASNet Small model.

Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
Paper: https://arxiv.org/pdf/1807.11626.pdf.

Args:
    variant: key into ``default_cfgs``; that config is passed to ``_create_model``.
    channel_multiplier: multiplier to number of channels per layer.
    pretrained: forwarded to ``_create_model`` to request pretrained weights.
    **kwargs: extra model options (presumably merged into ``model_kwargs``; not visible here).

Returns:
    The model built by ``_create_model``.
"""
arch_def = [
model_kwargs = dict(
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
return model
def _gen_mobilenet_v1(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
""" Generate MobileNet-V1 network

Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.py
Paper: https://arxiv.org/abs/1704.04861

Args:
    variant: key into ``default_cfgs``; that config is passed to ``_create_model``.
    channel_multiplier: multiplier to number of channels per layer.
    pretrained: forwarded to ``_create_model`` to request pretrained weights.
    **kwargs: extra model options (presumably merged into ``model_kwargs``; not visible here).

Returns:
    The model built by ``_create_model``.
"""
arch_def = [
model_kwargs = dict(
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
return model
def _gen_mobilenet_v2(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
""" Generate MobileNet-V2 network

Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
Paper: https://arxiv.org/abs/1801.04381

Args:
    variant: key into ``default_cfgs``; that config is passed to ``_create_model``.
    channel_multiplier: multiplier to number of channels per layer.
    pretrained: forwarded to ``_create_model`` to request pretrained weights.
    **kwargs: extra model options (presumably merged into ``model_kwargs``; not visible here).

Returns:
    The model built by ``_create_model``.
"""
arch_def = [
model_kwargs = dict(
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
return model
def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
"""Creates a MobileNet-V3 model.

Ref impl: not identified (TODO)
Paper: https://arxiv.org/abs/1905.02244

Block strings encode type/repeats/kernel/stride/expansion/channels; the ``nre``
suffix marks the early ReLU stages, later stages use hard-swish per the trailing
comments.

Args:
    variant: key into ``default_cfgs``; that config is passed to ``_create_model``.
    channel_multiplier: multiplier to number of channels per layer.
    pretrained: forwarded to ``_create_model`` to request pretrained weights.
    **kwargs: extra model options (presumably merged into ``model_kwargs``; not visible here).

Returns:
    The model built by ``_create_model``.
"""
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s1_e1_c16_nre_noskip'], # relu
# stage 1, 112x112 in
['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu
# stage 2, 56x56 in
['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu
# stage 3, 28x28 in
['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish
# stage 4, 14x14 in
['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish
# stage 5, 14x14 in
['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish
# stage 6, 7x7 in
['cn_r1_k1_s1_c960'], # hard-swish
model_kwargs = dict(
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
return model
def _gen_chamnet_v1(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
""" Generate Chameleon Network (ChamNet)

Paper: https://arxiv.org/abs/1812.08934
Ref Impl: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py

FIXME: this is a bit of an educated guess based on the trunk def in maskrcnn_benchmark

Args:
    variant: key into ``default_cfgs``; that config is passed to ``_create_model``.
    channel_multiplier: multiplier to number of channels per layer.
    pretrained: forwarded to ``_create_model`` to request pretrained weights.
    **kwargs: extra model options (presumably merged into ``model_kwargs``; not visible here).

Returns:
    The model built by ``_create_model``.
"""
arch_def = [
model_kwargs = dict(
num_features=1280, # uncertain: borrowed the MobileNet/MNASNet default head width
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
return model
def _gen_chamnet_v2(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
""" Generate Chameleon Network (ChamNet), second variant

Paper: https://arxiv.org/abs/1812.08934
Ref Impl: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py

FIXME: this is a bit of an educated guess based on the trunk def in maskrcnn_benchmark

Args:
    variant: key into ``default_cfgs``; that config is passed to ``_create_model``.
    channel_multiplier: multiplier to number of channels per layer.
    pretrained: forwarded to ``_create_model`` to request pretrained weights.
    **kwargs: extra model options (presumably merged into ``model_kwargs``; not visible here).

Returns:
    The model built by ``_create_model``.
"""
arch_def = [
model_kwargs = dict(
num_features=1280, # uncertain: borrowed the MobileNet/MNASNet default head width
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
return model
def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
""" FBNet-C

Paper: https://arxiv.org/abs/1812.03443
Ref Impl: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py

NOTE: the impl above does not define the 'C' variant; this one was derived from the
paper, and the impl was only used to confirm some building-block details.

Args:
    variant: key into ``default_cfgs``; that config is passed to ``_create_model``.
    channel_multiplier: multiplier to number of channels per layer.
    pretrained: forwarded to ``_create_model`` to request pretrained weights.
    **kwargs: extra model options (presumably merged into ``model_kwargs``; not visible here).

Returns:
    The model built by ``_create_model``.
"""
arch_def = [
['ir_r1_k3_s2_e6_c24', 'ir_r2_k3_s1_e1_c24'],
['ir_r1_k5_s2_e6_c32', 'ir_r1_k5_s1_e3_c32', 'ir_r1_k5_s1_e6_c32', 'ir_r1_k3_s1_e6_c32'],
['ir_r1_k5_s2_e6_c64', 'ir_r1_k5_s1_e3_c64', 'ir_r2_k5_s1_e6_c64'],
['ir_r3_k5_s1_e6_c112', 'ir_r1_k5_s1_e3_c112'],
model_kwargs = dict(
num_features=1984, # paper suggests this, but is not 100% clear
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
return model
def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
"""Creates the Single-Path NAS model from search targeted for Pixel1 phone.

Paper: https://arxiv.org/abs/1904.02877

Args:
    variant: key into ``default_cfgs``; that config is passed to ``_create_model``.
    channel_multiplier: multiplier to number of channels per layer.
    pretrained: forwarded to ``_create_model`` to request pretrained weights.
    **kwargs: extra model options (presumably merged into ``model_kwargs``; not visible here).

Returns:
    The model built by ``_create_model``.
"""
arch_def = [
# stage 0, 112x112 in
# stage 1, 112x112 in
# stage 2, 56x56 in
['ir_r1_k5_s2_e6_c40', 'ir_r3_k3_s1_e3_c40'],
# stage 3, 28x28 in
['ir_r1_k5_s2_e6_c80', 'ir_r3_k3_s1_e3_c80'],
# stage 4, 14x14 in
['ir_r1_k5_s1_e6_c96', 'ir_r3_k5_s1_e3_c96'],
# stage 5, 14x14 in
# stage 6, 7x7 in
model_kwargs = dict(
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
return model
def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
"""Creates an EfficientNet model.

Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
Paper: https://arxiv.org/abs/1905.11946

EfficientNet params
name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
'efficientnet-b0': (1.0, 1.0, 224, 0.2),
'efficientnet-b1': (1.0, 1.1, 240, 0.2),
'efficientnet-b2': (1.1, 1.2, 260, 0.3),
'efficientnet-b3': (1.2, 1.4, 300, 0.3),
'efficientnet-b4': (1.4, 1.8, 380, 0.4),
'efficientnet-b5': (1.6, 2.2, 456, 0.4),
'efficientnet-b6': (1.8, 2.6, 528, 0.5),
'efficientnet-b7': (2.0, 3.1, 600, 0.5),

Args:
    variant: key into ``default_cfgs``; that config is passed to ``_create_model``.
    channel_multiplier: multiplier to number of channels per layer
    depth_multiplier: multiplier to number of repeats per stage
    pretrained: forwarded to ``_create_model`` to request pretrained weights.
    **kwargs: extra model options (presumably merged into ``model_kwargs``; not visible here).

Returns:
    The model built by ``_create_model``.
"""
arch_def = [
model_kwargs = dict(
# stage repeat counts scaled by depth_multiplier
block_args=_decode_arch_def(arch_def, depth_multiplier),
# head width 1280 scaled by channel_multiplier; 8 / None presumably divisor / min — TODO confirm
num_features=_round_channels(1280, channel_multiplier, 8, None),
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
return model
def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
""" Creates an EfficientNet-EdgeTPU model

Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/edgetpu

Args:
    variant: key into ``default_cfgs``; that config is passed to ``_create_model``.
    channel_multiplier: multiplier to number of channels per layer.
    depth_multiplier: multiplier to number of repeats per stage.
    pretrained: forwarded to ``_create_model`` to request pretrained weights.
    **kwargs: extra model options (presumably merged into ``model_kwargs``; not visible here).

Returns:
    The model built by ``_create_model``.
"""
arch_def = [
# NOTE `fc` is present to override a mismatch between stem channels and in chs not
# present in other models
model_kwargs = dict(
block_args=_decode_arch_def(arch_def, depth_multiplier),
num_features=_round_channels(1280, channel_multiplier, 8, None),
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
return model
def _gen_efficientnet_condconv(
variant, channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=1, pretrained=False, **kwargs):
"""Creates an EfficientNet-CondConv model.

Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/condconv

Args:
    variant: key into ``default_cfgs``; that config is passed to ``_create_model``.
    channel_multiplier: multiplier to number of channels per layer.
    depth_multiplier: multiplier to number of repeats per stage.
    experts_multiplier: multiplier applied to each stage's base CondConv expert count.
    pretrained: forwarded to ``_create_model`` to request pretrained weights.
    **kwargs: extra model options (presumably merged into ``model_kwargs``; not visible here).

Returns:
    The model built by ``_create_model``.
"""
arch_def = [
# NOTE unlike official impl, this one uses `cc<x>` option where x is the base number of experts for each stage and
# the expert_multiplier increases that on a per-model basis as with depth/channel multipliers
model_kwargs = dict(
block_args=_decode_arch_def(arch_def, depth_multiplier, experts_multiplier=experts_multiplier),
num_features=_round_channels(1280, channel_multiplier, 8, None),
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
return model
def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
"""Creates a MixNet Small model.

Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet
Paper: https://arxiv.org/abs/1907.09595

Args:
    variant: key into ``default_cfgs``; that config is passed to ``_create_model``.
    channel_multiplier: multiplier to number of channels per layer.
    pretrained: forwarded to ``_create_model`` to request pretrained weights.
    **kwargs: extra model options (presumably merged into ``model_kwargs``; not visible here).

Returns:
    The model built by ``_create_model``.
"""
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s1_e1_c16'], # relu
# stage 1, 112x112 in
['ir_r1_k3_a1.1_p1.1_s2_e6_c24', 'ir_r1_k3_a1.1_p1.1_s1_e3_c24'], # relu
# stage 2, 56x56 in
['ir_r1_k3.5.7_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish
# stage 3, 28x28 in
['ir_r1_k3.5.7_p1.1_s2_e6_c80_se0.25_nsw', 'ir_r2_k3.5_p1.1_s1_e6_c80_se0.25_nsw'], # swish
# stage 4, 14x14 in
['ir_r1_k3.5.7_a1.1_p1.1_s1_e6_c120_se0.5_nsw', 'ir_r2_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish
# stage 5, 14x14 in
# NOTE(review): 'ir_r1_k3.' looks truncated — unlike every other entry it has no stride/expansion/
# channel fields (presumably something like 'ir_r1_k3.5.7.9_s2_e6_c200_se0.5_nsw'); confirm vs ref impl.
['ir_r1_k3.', 'ir_r2_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish
# 7x7
model_kwargs = dict(
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
return model
def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
"""Creates a MixNet Medium-Large model (shared by the m / l / xl / xxl variants).

Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet
Paper: https://arxiv.org/abs/1907.09595

Args:
    variant: key into ``default_cfgs``; that config is passed to ``_create_model``.
    channel_multiplier: multiplier to number of channels per layer.
    depth_multiplier: multiplier to number of repeats per stage.
    pretrained: forwarded to ``_create_model`` to request pretrained weights.
    **kwargs: extra model options (presumably merged into ``model_kwargs``; not visible here).

Returns:
    The model built by ``_create_model``.
"""
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s1_e1_c24'], # relu
# stage 1, 112x112 in
['ir_r1_k3.5.7_a1.1_p1.1_s2_e6_c32', 'ir_r1_k3_a1.1_p1.1_s1_e3_c32'], # relu
# stage 2, 56x56 in
['ir_r1_k3.5.7.9_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish
# stage 3, 28x28 in
['ir_r1_k3.5.7_s2_e6_c80_se0.25_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e6_c80_se0.25_nsw'], # swish
# stage 4, 14x14 in
['ir_r1_k3_s1_e6_c120_se0.5_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish
# stage 5, 14x14 in
['ir_r1_k3.5.7.9_s2_e6_c200_se0.5_nsw', 'ir_r3_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish
# 7x7
model_kwargs = dict(
# depth_trunc='round' presumably rounds the scaled repeat counts instead of the default — confirm in _decode_arch_def
block_args=_decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'),
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
return model
def mnasnet_050(pretrained=False, **kwargs):
    """ MNASNet B1, channel multiplier of 0.5. """
    model = _gen_mnasnet_b1('mnasnet_050', 0.5, pretrained=pretrained, **kwargs)
    return model


def mnasnet_075(pretrained=False, **kwargs):
    """ MNASNet B1, channel multiplier of 0.75. """
    model = _gen_mnasnet_b1('mnasnet_075', 0.75, pretrained=pretrained, **kwargs)
    return model


def mnasnet_100(pretrained=False, **kwargs):
    """ MNASNet B1, channel multiplier of 1.0. """
    model = _gen_mnasnet_b1('mnasnet_100', 1.0, pretrained=pretrained, **kwargs)
    return model


def mnasnet_b1(pretrained=False, **kwargs):
    """ MNASNet B1, channel multiplier of 1.0 (alias of ``mnasnet_100``). """
    return mnasnet_100(pretrained=pretrained, **kwargs)


def mnasnet_140(pretrained=False, **kwargs):
    """ MNASNet B1, channel multiplier of 1.4. """
    model = _gen_mnasnet_b1('mnasnet_140', 1.4, pretrained=pretrained, **kwargs)
    return model
def semnasnet_050(pretrained=False, **kwargs):
    """ MNASNet A1 (w/ SE), channel multiplier of 0.5. """
    model = _gen_mnasnet_a1('semnasnet_050', 0.5, pretrained=pretrained, **kwargs)
    return model


def semnasnet_075(pretrained=False, **kwargs):
    """ MNASNet A1 (w/ SE), channel multiplier of 0.75. """
    model = _gen_mnasnet_a1('semnasnet_075', 0.75, pretrained=pretrained, **kwargs)
    return model


def semnasnet_100(pretrained=False, **kwargs):
    """ MNASNet A1 (w/ SE), channel multiplier of 1.0. """
    model = _gen_mnasnet_a1('semnasnet_100', 1.0, pretrained=pretrained, **kwargs)
    return model


def mnasnet_a1(pretrained=False, **kwargs):
    """ MNASNet A1 (w/ SE), channel multiplier of 1.0 (alias of ``semnasnet_100``). """
    return semnasnet_100(pretrained=pretrained, **kwargs)


def semnasnet_140(pretrained=False, **kwargs):
    """ MNASNet A1 (w/ SE), channel multiplier of 1.4. """
    model = _gen_mnasnet_a1('semnasnet_140', 1.4, pretrained=pretrained, **kwargs)
    return model


def mnasnet_small(pretrained=False, **kwargs):
    """ MNASNet Small, channel multiplier of 1.0. """
    model = _gen_mnasnet_small('mnasnet_small', 1.0, pretrained=pretrained, **kwargs)
    return model
def mobilenetv1_100(pretrained=False, **kwargs):
    """ MobileNet V1 with channel multiplier 1.0. """
    return _gen_mobilenet_v1('mobilenetv1_100', channel_multiplier=1.0, pretrained=pretrained, **kwargs)


def mobilenetv2_100(pretrained=False, **kwargs):
    """ MobileNet V2 with channel multiplier 1.0. """
    return _gen_mobilenet_v2('mobilenetv2_100', channel_multiplier=1.0, pretrained=pretrained, **kwargs)


def mobilenetv3_050(pretrained=False, **kwargs):
    """ MobileNet V3 with channel multiplier 0.5. """
    return _gen_mobilenet_v3('mobilenetv3_050', channel_multiplier=0.5, pretrained=pretrained, **kwargs)


def mobilenetv3_075(pretrained=False, **kwargs):
    """ MobileNet V3 with channel multiplier 0.75. """
    return _gen_mobilenet_v3('mobilenetv3_075', channel_multiplier=0.75, pretrained=pretrained, **kwargs)


def mobilenetv3_100(pretrained=False, **kwargs):
    """ MobileNet V3 with channel multiplier 1.0.

    The pretrained weights were trained with a non-default BN epsilon, so
    ``bn_eps`` is set to the TF default whenever ``pretrained`` is requested.
    """
    if pretrained:
        kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
    return _gen_mobilenet_v3('mobilenetv3_100', channel_multiplier=1.0, pretrained=pretrained, **kwargs)
def fbnetc_100(pretrained=False, **kwargs):
    """ FBNet-C with channel multiplier 1.0.

    The pretrained weights use the TF default BN epsilon, which is applied only
    when ``pretrained`` is requested.
    """
    if pretrained:
        kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
    return _gen_fbnetc('fbnetc_100', channel_multiplier=1.0, pretrained=pretrained, **kwargs)


def chamnetv1_100(pretrained=False, **kwargs):
    """ ChamNet (v1 definition) with channel multiplier 1.0. """
    return _gen_chamnet_v1('chamnetv1_100', channel_multiplier=1.0, pretrained=pretrained, **kwargs)


def chamnetv2_100(pretrained=False, **kwargs):
    """ ChamNet (v2 definition) with channel multiplier 1.0. """
    return _gen_chamnet_v2('chamnetv2_100', channel_multiplier=1.0, pretrained=pretrained, **kwargs)


def spnasnet_100(pretrained=False, **kwargs):
    """ Single-Path NAS (Pixel1 search) with channel multiplier 1.0. """
    return _gen_spnasnet('spnasnet_100', channel_multiplier=1.0, pretrained=pretrained, **kwargs)
def efficientnet_b0(pretrained=False, **kwargs):
    """ EfficientNet-B0 (channel x1.0, depth x1.0).

    Training: use drop_rate=0.2; drop_connect_rate=0.2 is meant to be set when
    training (not set here — TODO expose as a cmd arg).
    """
    return _gen_efficientnet(
        'efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)


def efficientnet_b1(pretrained=False, **kwargs):
    """ EfficientNet-B1 (channel x1.0, depth x1.1).

    Training: use drop_rate=0.2; drop_connect_rate=0.2 is meant to be set when
    training (not set here — TODO expose as a cmd arg).
    """
    return _gen_efficientnet(
        'efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)


def efficientnet_b2(pretrained=False, **kwargs):
    """ EfficientNet-B2 (channel x1.1, depth x1.2).

    Training: use drop_rate=0.3; drop_connect_rate=0.2 is meant to be set when
    training (not set here — TODO expose as a cmd arg).
    """
    return _gen_efficientnet(
        'efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)


def efficientnet_b3(pretrained=False, **kwargs):
    """ EfficientNet-B3 (channel x1.2, depth x1.4).

    Training: use drop_rate=0.3; drop_connect_rate=0.2 is meant to be set when
    training (not set here — TODO expose as a cmd arg).
    """
    return _gen_efficientnet(
        'efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)


def efficientnet_b4(pretrained=False, **kwargs):
    """ EfficientNet-B4 (channel x1.4, depth x1.8).

    Training: use drop_rate=0.4; drop_connect_rate=0.2 is meant to be set when
    training (not set here — TODO expose as a cmd arg).
    """
    return _gen_efficientnet(
        'efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)


def efficientnet_b5(pretrained=False, **kwargs):
    """ EfficientNet-B5 (channel x1.6, depth x2.2).

    Training: use drop_rate=0.4; drop_connect_rate=0.2 is meant to be set when
    training (not set here — TODO expose as a cmd arg).
    """
    return _gen_efficientnet(
        'efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs)


def efficientnet_b6(pretrained=False, **kwargs):
    """ EfficientNet-B6 (channel x1.8, depth x2.6).

    Training: use drop_rate=0.5; drop_connect_rate=0.2 is meant to be set when
    training (not set here — TODO expose as a cmd arg).
    """
    return _gen_efficientnet(
        'efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs)


def efficientnet_b7(pretrained=False, **kwargs):
    """ EfficientNet-B7 (channel x2.0, depth x3.1).

    Training: use drop_rate=0.5; drop_connect_rate=0.2 is meant to be set when
    training (not set here — TODO expose as a cmd arg).
    """
    return _gen_efficientnet(
        'efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs)
def efficientnet_es(pretrained=False, **kwargs):
    """ EfficientNet-EdgeTPU Small (channel x1.0, depth x1.0). """
    return _gen_efficientnet_edge(
        'efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)


def efficientnet_em(pretrained=False, **kwargs):
    """ EfficientNet-EdgeTPU Medium (channel x1.0, depth x1.1). """
    return _gen_efficientnet_edge(
        'efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)


def efficientnet_el(pretrained=False, **kwargs):
    """ EfficientNet-EdgeTPU Large (channel x1.2, depth x1.4). """
    return _gen_efficientnet_edge(
        'efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
def efficientnet_cc_b0_4e(pretrained=False, **kwargs):
    """ EfficientNet-CondConv-B0 w/ 4 Experts (default experts_multiplier of 1).

    Training: use drop_rate=0.2; drop_connect_rate=0.2 is meant to be set when
    training (not set here — TODO expose as a cmd arg).
    """
    model = _gen_efficientnet_condconv(
        'efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
    return model


def efficientnet_cc_b0_8e(pretrained=False, **kwargs):
    """ EfficientNet-CondConv-B0 w/ 8 Experts (experts_multiplier=2).

    Training: use drop_rate=0.2; drop_connect_rate=0.2 is meant to be set when
    training (not set here — TODO expose as a cmd arg).
    """
    model = _gen_efficientnet_condconv(
        'efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2,
        pretrained=pretrained, **kwargs)
    return model


def efficientnet_cc_b1_8e(pretrained=False, **kwargs):
    """ EfficientNet-CondConv-B1 w/ 8 Experts (experts_multiplier=2).

    Training: use drop_rate=0.2; drop_connect_rate=0.2 is meant to be set when
    training (not set here — TODO expose as a cmd arg).
    """
    model = _gen_efficientnet_condconv(
        'efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2,
        pretrained=pretrained, **kwargs)
    return model
def tf_efficientnet_b0(pretrained=False, **kwargs):
    """ EfficientNet-B0. Tensorflow compatible variant (TF default BN eps, 'same' padding) """
    kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_b1(pretrained=False, **kwargs):
    """ EfficientNet-B1. Tensorflow compatible variant (TF default BN eps, 'same' padding) """
    kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_b2(pretrained=False, **kwargs):
    """ EfficientNet-B2. Tensorflow compatible variant (TF default BN eps, 'same' padding) """
    kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_b3(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """ EfficientNet-B3. Tensorflow compatible variant (TF default BN eps, 'same' padding)

    ``num_classes`` and ``in_chans`` are forwarded to the model builder.
    """
    kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    # These two are explicit parameters (kept for backward compatibility), so they never
    # reach **kwargs on their own; forward them or they are silently ignored.
    kwargs['num_classes'] = num_classes
    kwargs['in_chans'] = in_chans
    model = _gen_efficientnet(
        'tf_efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_b4(pretrained=False, **kwargs):
    """ EfficientNet-B4. Tensorflow compatible variant (TF default BN eps, 'same' padding) """
    kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_b5(pretrained=False, **kwargs):
    """ EfficientNet-B5. Tensorflow compatible variant (TF default BN eps, 'same' padding) """
    kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_b6(pretrained=False, **kwargs):
    """ EfficientNet-B6. Tensorflow compatible variant (TF default BN eps, 'same' padding) """
    # NOTE for train, drop_rate should be 0.5
    kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_b7(pretrained=False, **kwargs):
    """ EfficientNet-B7. Tensorflow compatible variant (TF default BN eps, 'same' padding) """
    # NOTE for train, drop_rate should be 0.5
    kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs)
    return model
def tf_efficientnet_es(pretrained=False, **kwargs):
    """ EfficientNet-EdgeTPU Small, TF-compatible (TF default BN eps, 'same' padding). """
    kwargs.update(bn_eps=_BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet_edge(
        'tf_efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)


def tf_efficientnet_em(pretrained=False, **kwargs):
    """ EfficientNet-EdgeTPU Medium, TF-compatible (TF default BN eps, 'same' padding). """
    kwargs.update(bn_eps=_BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet_edge(
        'tf_efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)


def tf_efficientnet_el(pretrained=False, **kwargs):
    """ EfficientNet-EdgeTPU Large, TF-compatible (TF default BN eps, 'same' padding). """
    kwargs.update(bn_eps=_BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet_edge(
        'tf_efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
def tf_efficientnet_cc_b0_4e(pretrained=False, **kwargs):
    """ EfficientNet-CondConv-B0 w/ 4 Experts, TF-compatible (TF default BN eps, 'same' padding).

    Training: use drop_rate=0.2; drop_connect_rate=0.2 is meant to be set when
    training (not set here — TODO expose as a cmd arg).
    """
    kwargs.update(bn_eps=_BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet_condconv(
        'tf_efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)


def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs):
    """ EfficientNet-CondConv-B0 w/ 8 Experts, TF-compatible (TF default BN eps, 'same' padding).

    Training: use drop_rate=0.2; drop_connect_rate=0.2 is meant to be set when
    training (not set here — TODO expose as a cmd arg).
    """
    kwargs.update(bn_eps=_BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet_condconv(
        'tf_efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2,
        pretrained=pretrained, **kwargs)


def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs):
    """ EfficientNet-CondConv-B1 w/ 8 Experts, TF-compatible (TF default BN eps, 'same' padding).

    Training: use drop_rate=0.2; drop_connect_rate=0.2 is meant to be set when
    training (not set here — TODO expose as a cmd arg).
    """
    kwargs.update(bn_eps=_BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet_condconv(
        'tf_efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2,
        pretrained=pretrained, **kwargs)
def mixnet_s(pretrained=False, **kwargs):
    """Creates a MixNet Small model."""
    model = _gen_mixnet_s(
        'mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs)
    return model


def mixnet_m(pretrained=False, **kwargs):
    """Creates a MixNet Medium model."""
    model = _gen_mixnet_m(
        'mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs)
    return model


def mixnet_l(pretrained=False, **kwargs):
    """Creates a MixNet Large model (MixNet-M definition, channel multiplier 1.3)."""
    model = _gen_mixnet_m(
        'mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs)
    return model


def mixnet_xl(pretrained=False, **kwargs):
    """Creates a MixNet Extra-Large model.

    Not a paper spec, experimental def by RW w/ depth scaling.
    """
    model = _gen_mixnet_m(
        'mixnet_xl', channel_multiplier=1.6, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
    return model


def mixnet_xxl(pretrained=False, **kwargs):
    """Creates a MixNet Double Extra Large model.

    Not a paper spec, experimental def by RW w/ depth scaling.
    """
    # kwargs['drop_connect_rate'] = 0.2
    model = _gen_mixnet_m(
        'mixnet_xxl', channel_multiplier=2.4, depth_multiplier=1.3, pretrained=pretrained, **kwargs)
    return model
def tf_mixnet_s(pretrained=False, **kwargs):
    """Creates a MixNet Small model. Tensorflow compatible variant (TF default BN eps, 'same' padding)."""
    kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_mixnet_s(
        'tf_mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs)
    return model


def tf_mixnet_m(pretrained=False, **kwargs):
    """Creates a MixNet Medium model. Tensorflow compatible variant (TF default BN eps, 'same' padding)."""
    kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_mixnet_m(
        'tf_mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs)
    return model


def tf_mixnet_l(pretrained=False, **kwargs):
    """Creates a MixNet Large model. Tensorflow compatible variant (TF default BN eps, 'same' padding)."""
    kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_mixnet_m(
        'tf_mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs)
    return model
def gen_efficientnet_model_names():
    """Return a new ``set`` built from the module-level ``_models`` collection.

    Presumably these are the registered model names; the result is a fresh copy
    that the caller may mutate freely.
    """
    return {name for name in _models}