From 506df0e3d0136da4063b2a2881958e57cf43c784 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 22 Oct 2019 23:42:04 -0700 Subject: [PATCH] Add CondConv support for EfficientNet into WIP for GenEfficientNet Feature extraction setup --- timm/models/conv2d_helpers.py | 120 --- timm/models/conv2d_layers.py | 255 ++++++ timm/models/gen_efficientnet.py | 1295 +++++++++++++++++-------------- timm/models/helpers.py | 3 +- 4 files changed, 978 insertions(+), 695 deletions(-) delete mode 100644 timm/models/conv2d_helpers.py create mode 100644 timm/models/conv2d_layers.py diff --git a/timm/models/conv2d_helpers.py b/timm/models/conv2d_helpers.py deleted file mode 100644 index 674eadca..00000000 --- a/timm/models/conv2d_helpers.py +++ /dev/null @@ -1,120 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import math - - -def _is_static_pad(kernel_size, stride=1, dilation=1, **_): - return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 - - -def _get_padding(kernel_size, stride=1, dilation=1, **_): - padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 - return padding - - -def _calc_same_pad(i, k, s, d): - return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) - - -def _split_channels(num_chan, num_groups): - split = [num_chan // num_groups for _ in range(num_groups)] - split[0] += num_chan - sum(split) - return split - - -class Conv2dSame(nn.Conv2d): - """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions - """ - def __init__(self, in_channels, out_channels, kernel_size, stride=1, - padding=0, dilation=1, groups=1, bias=True): - super(Conv2dSame, self).__init__( - in_channels, out_channels, kernel_size, stride, 0, dilation, - groups, bias) - - def forward(self, x): - ih, iw = x.size()[-2:] - kh, kw = self.weight.size()[-2:] - pad_h = _calc_same_pad(ih, kh, self.stride[0], self.dilation[0]) - pad_w = _calc_same_pad(iw, kw, self.stride[1], self.dilation[1]) - if pad_h > 0 or pad_w > 0: - x = F.pad(x, [pad_w//2, pad_w - pad_w//2, pad_h//2, pad_h - pad_h//2]) - return F.conv2d(x, self.weight, self.bias, self.stride, - self.padding, self.dilation, self.groups) - - -def conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): - padding = kwargs.pop('padding', '') - kwargs.setdefault('bias', False) - if isinstance(padding, str): - # for any string padding, the padding will be calculated for you, one of three ways - padding = padding.lower() - if padding == 'same': - # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact - if _is_static_pad(kernel_size, **kwargs): - # static case, no extra overhead - padding = _get_padding(kernel_size, **kwargs) - return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) - else: - # dynamic padding - return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) - elif padding == 'valid': - # 'VALID' padding, same as padding=0 - return nn.Conv2d(in_chs, out_chs, kernel_size, padding=0, **kwargs) - else: - # Default to PyTorch style 'same'-ish symmetric padding - padding = _get_padding(kernel_size, **kwargs) - return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) - else: - # padding was specified as a number or pair - return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) - - -class MixedConv2d(nn.Module): - """ Mixed Grouped Convolution - Based on MDConv and GroupedConv in MixNet impl: - https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py - """ - - def __init__(self, in_channels, 
out_channels, kernel_size=3, - stride=1, padding='', dilated=False, depthwise=False, **kwargs): - super(MixedConv2d, self).__init__() - - kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] - num_groups = len(kernel_size) - in_splits = _split_channels(in_channels, num_groups) - out_splits = _split_channels(out_channels, num_groups) - for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): - d = 1 - # FIXME make compat with non-square kernel/dilations/strides - if stride == 1 and dilated: - d, k = (k - 1) // 2, 3 - conv_groups = out_ch if depthwise else 1 - # use add_module to keep key space clean - self.add_module( - str(idx), - conv2d_pad( - in_ch, out_ch, k, stride=stride, - padding=padding, dilation=d, groups=conv_groups, **kwargs) - ) - self.splits = in_splits - - def forward(self, x): - x_split = torch.split(x, self.splits, 1) - x_out = [c(x) for x, c in zip(x_split, self._modules.values())] - x = torch.cat(x_out, 1) - return x - - -# helper method -def select_conv2d(in_chs, out_chs, kernel_size, **kwargs): - assert 'groups' not in kwargs # only use 'depthwise' bool arg - if isinstance(kernel_size, list): - # We're going to use only lists for defining the MixedConv2d kernel groups, - # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. - return MixedConv2d(in_chs, out_chs, kernel_size, **kwargs) - else: - depthwise = kwargs.pop('depthwise', False) - groups = out_chs if depthwise else 1 - return conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs) - diff --git a/timm/models/conv2d_layers.py b/timm/models/conv2d_layers.py new file mode 100644 index 00000000..cd52b885 --- /dev/null +++ b/timm/models/conv2d_layers.py @@ -0,0 +1,255 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch._six import container_abcs +from itertools import repeat +from functools import partial +import numpy as np +import math + + +def _ntuple(n): + def parse(x): + if isinstance(x, container_abcs.Iterable): + return x + return tuple(repeat(x, n)) + return parse + + +_single = _ntuple(1) +_pair = _ntuple(2) +_triple = _ntuple(3) +_quadruple = _ntuple(4) + + +def _is_static_pad(kernel_size, stride=1, dilation=1, **_): + return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 + + +def _get_padding(kernel_size, stride=1, dilation=1, **_): + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + return padding + + +def _calc_same_pad(i, k, s, d): + return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) + + +def _split_channels(num_chan, num_groups): + split = [num_chan // num_groups for _ in range(num_groups)] + split[0] += num_chan - sum(split) + return split + + +# pylint: disable=unused-argument +def conv2d_same(x, weight, bias=None, stride=(1, 1), padding=(0, 0), dilation=(1, 1), groups=1): + ih, iw = x.size()[-2:] + kh, kw = weight.size()[-2:] + pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0]) + pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1]) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) + return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) + + +class Conv2dSame(nn.Conv2d): + """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions + """ + + # pylint: disable=unused-argument + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True): + super(Conv2dSame, self).__init__( + in_channels, out_channels, 
kernel_size, stride, 0, dilation, groups, bias) + + def forward(self, x): + return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +def get_padding_value(padding, kernel_size, **kwargs): + dynamic = False + if isinstance(padding, str): + # for any string padding, the padding will be calculated for you, one of three ways + padding = padding.lower() + if padding == 'same': + # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact + if _is_static_pad(kernel_size, **kwargs): + # static case, no extra overhead + padding = _get_padding(kernel_size, **kwargs) + else: + # dynamic padding + padding = 0 + dynamic = True + elif padding == 'valid': + # 'VALID' padding, same as padding=0 + padding = 0 + else: + # Default to PyTorch style 'same'-ish symmetric padding + padding = _get_padding(kernel_size, **kwargs) + return padding, dynamic + + +def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): + padding = kwargs.pop('padding', '') + kwargs.setdefault('bias', False) + padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) + if is_dynamic: + return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) + else: + return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) + + +class MixedConv2d(nn.Module): + """ Mixed Grouped Convolution + Based on MDConv and GroupedConv in MixNet impl: + https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py + """ + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding='', dilation=1, mixed_dilated=False, depthwise=False, **kwargs): + super(MixedConv2d, self).__init__() + + kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] + num_groups = len(kernel_size) + in_splits = _split_channels(in_channels, num_groups) + out_splits = _split_channels(out_channels, num_groups) + self.in_channels = sum(in_splits) + self.out_channels = sum(out_splits) + for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): + d = dilation + # FIXME make compat with non-square kernel/dilations/strides + if stride == 1 and mixed_dilated: + d, k = (k - 1) // 2, 3 + conv_groups = out_ch if depthwise else 1 + # use add_module to keep key space clean + self.add_module( + str(idx), + create_conv2d_pad( + in_ch, out_ch, k, stride=stride, + padding=padding, dilation=d, groups=conv_groups, **kwargs) + ) + self.splits = in_splits + + def forward(self, x): + x_split = torch.split(x, self.splits, 1) + x_out = [c(x) for x, c in zip(x_split, self._modules.values())] + x = torch.cat(x_out, 1) + return x + + +def get_condconv_initializer(initializer, num_experts, expert_shape): + def condconv_initializer(weight): + """CondConv initializer function.""" + num_params = np.prod(expert_shape) + if (len(weight.shape) != 2 or weight.shape[0] != num_experts or + weight.shape[1] != num_params): + raise (ValueError( + 'CondConv variables must have shape [num_experts, num_params]')) + for i in range(num_experts): + initializer(weight[i].view(expert_shape)) + return condconv_initializer + + +class CondConv2d(nn.Module): + """ Conditional Convolution + Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py + """ + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4): + super(CondConv2d, self).__init__() + + self.in_channels = in_channels + 
self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + padding_val, is_padding_dynamic = get_padding_value( + padding, kernel_size, stride=stride, dilation=dilation) + self.conv_fn = conv2d_same if is_padding_dynamic else F.conv2d + self.padding = _pair(padding_val) + self.dilation = _pair(dilation) + self.transposed = False + self.output_padding = _pair(0) + self.groups = groups + self.padding_mode = 'zero' + self.num_experts = num_experts + + self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size + weight_num_param = 1 + for wd in self.weight_shape: + weight_num_param *= wd + self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param)) + + # FIXME I haven't tested bias yet + if bias: + self.bias_shape = (self.out_channels,) + condconv_bias_shape = (self.num_experts, self.out_channels) + self.bias = torch.nn.Parameter(torch.Tensor(condconv_bias_shape)) + else: + self.register_parameter('bias', None) + + self.reset_parameters() + # FIXME once I'm satisfied this works, remove the looping path? + self._use_groups = True # use groups for parallel per-batch-element kernel convolution + + def reset_parameters(self): + init_weight = get_condconv_initializer( + partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape) + init_weight(self.weight) + if self.bias is not None: + # FIXME bias not tested + fan_in = np.prod(self.weight_shape[1:]) + bound = 1 / math.sqrt(fan_in) + init_bias = get_condconv_initializer( + partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape) + init_bias(self.bias) + + def forward(self, x, routing_weights): + weight = torch.matmul(routing_weights, self.weight) + bias = torch.matmul(routing_weights, self.bias) if self.bias is not None else None + B, C, H, W = x.shape + if self._use_groups: + new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size + weight = weight.view(new_weight_shape) + x = x.view(1, B * C, H, W) + out = self.conv_fn( + x, weight, bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups * B) + out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1]) + else: + x = torch.split(x, 1, 0) + weight = torch.split(weight, 1, 0) + if self.bias is not None: + bias = torch.matmul(routing_weights, self.bias) + bias = torch.split(bias, 1, 0) + else: + bias = [None] * B + out = [] + for xi, wi, bi in zip(x, weight, bias): + wi = wi.view(*self.weight_shape) + if bi is not None: + bi = bi.view(*self.bias_shape) + out.append(self.conv_fn( + xi, wi, bi, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups)) + out = torch.cat(out, 0) + return out + + +# helper method +def select_conv2d(in_chs, out_chs, kernel_size, **kwargs): + assert 'groups' not in kwargs # only use 'depthwise' bool arg + if isinstance(kernel_size, list): + assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently + # We're going to use only lists for defining the MixedConv2d kernel groups, + # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. 
+ return MixedConv2d(in_chs, out_chs, kernel_size, **kwargs) + else: + depthwise = kwargs.pop('depthwise', False) + groups = out_chs if depthwise else 1 + if 'num_experts' in kwargs and kwargs['num_experts'] > 0: + create_fn = CondConv2d + else: + create_fn = create_conv2d_pad + return create_fn(in_chs, out_chs, kernel_size, groups=groups, **kwargs) + diff --git a/timm/models/gen_efficientnet.py b/timm/models/gen_efficientnet.py index a7191025..e51bab2a 100644 --- a/timm/models/gen_efficientnet.py +++ b/timm/models/gen_efficientnet.py @@ -19,16 +19,19 @@ import math import re import logging from copy import deepcopy +from functools import partial +from collections import OrderedDict, defaultdict import torch import torch.nn as nn import torch.nn.functional as F -from .registry import register_model +from timm.models.activations import Swish, sigmoid, HardSwish, hard_sigmoid +from .registry import register_model, model_entrypoint from .helpers import load_pretrained from .adaptive_avgmax_pool import SelectAdaptivePool2d -from .conv2d_helpers import select_conv2d -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .conv2d_layers import select_conv2d +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD __all__ = ['GenEfficientNet'] @@ -96,6 +99,9 @@ default_cfgs = { 'efficientnet_el': _cfg( url='', input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'efficientnet_cc_b0_4e': _cfg(url=''), + 'efficientnet_cc_b0_8e': _cfg(url=''), + 'efficientnet_cc_b1_8e': _cfg(url='', input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), 'tf_efficientnet_b0': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth', input_size=(3, 224, 224)), @@ -132,6 +138,16 @@ default_cfgs = { url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_el-5143854e.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'tf_efficientnet_cc_b0_4e': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_4e-4362b6b2.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_efficientnet_cc_b0_8e': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_8e-66184a25.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_efficientnet_cc_b1_8e': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b1_8e-f7c79ae1.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), 'mixnet_s': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth'), 'mixnet_m': _cfg( @@ -150,7 +166,7 @@ default_cfgs = { } -_DEBUG = False +_DEBUG = True # Default args for PyTorch BN impl _BN_MOMENTUM_PT_DEFAULT = 0.1 @@ -201,7 +217,7 @@ def _parse_ksize(ss): return [int(k) for k in ss.split('.')] -def _decode_block_str(block_str, depth_multiplier=1.0): +def _decode_block_str(block_str): """ Decode block definition string Gets a list of block arg (dicts) through a string notation of arguments. 
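As a quick orientation for the block-string notation referenced in the docstring above (and extended in this change with a 'cc' option for CondConv experts), the following is an illustrative, hand-worked decomposition of one string from the condconv arch_def added later in this patch. It is not code from the patch; the dict is only an approximation of what the 'ir' branch of _decode_block_str produces.

# Rough illustration only -- not part of the patch.
# 'ir_r1_k3_s1_e6_c320_se0.25_cc4' reads field by field as:
#   ir      -> block_type: inverted residual
#   r1      -> repeat count for the stage (returned separately and consumed by _scale_stage_depth)
#   k3      -> depthwise kernel size 3
#   s1      -> stride = 1
#   e6      -> exp_ratio = 6.0
#   c320    -> out_chs = 320
#   se0.25  -> se_ratio = 0.25
#   cc4     -> num_experts = 4 (the new CondConv option)
example_ir_block_args = dict(   # approximate; field names follow the 'ir' branch below
    block_type='ir',
    dw_kernel_size=3,
    out_chs=320,
    exp_ratio=6.0,
    se_ratio=0.25,
    stride=1,
    act_layer=None,             # None falls back to the model-wide default activation
    noskip=False,
    num_experts=4,              # parsed from 'cc4'; 0 means a plain (non-CondConv) block
)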
@@ -241,13 +257,13 @@ def _decode_block_str(block_str, depth_multiplier=1.0): key = op[0] v = op[1:] if v == 're': - value = F.relu + value = nn.ReLU elif v == 'r6': - value = F.relu6 + value = nn.ReLU6 elif v == 'hs': - value = hard_swish + value = HardSwish elif v == 'sw': - value = swish + value = Swish else: continue options[key] = value @@ -258,8 +274,8 @@ def _decode_block_str(block_str, depth_multiplier=1.0): key, value = splits[:2] options[key] = value - # if act_fn is None, the model default (passed to model init) will be used - act_fn = options['n'] if 'n' in options else None + # if act_layer is None, the model default (passed to model init) will be used + act_layer = options['n'] if 'n' in options else None exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1 pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1 fake_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def @@ -276,8 +292,9 @@ def _decode_block_str(block_str, depth_multiplier=1.0): exp_ratio=float(options['e']), se_ratio=float(options['se']) if 'se' in options else None, stride=int(options['s']), - act_fn=act_fn, + act_layer=act_layer, noskip=noskip, + num_experts=int(options['cc']) if 'cc' in options else 0 ) elif block_type == 'ds' or block_type == 'dsa': block_args = dict( @@ -287,7 +304,7 @@ def _decode_block_str(block_str, depth_multiplier=1.0): out_chs=int(options['c']), se_ratio=float(options['se']) if 'se' in options else None, stride=int(options['s']), - act_fn=act_fn, + act_layer=act_layer, pw_act=block_type == 'dsa', noskip=block_type == 'dsa' or noskip, ) @@ -301,7 +318,7 @@ def _decode_block_str(block_str, depth_multiplier=1.0): fake_in_chs=fake_in_chs, se_ratio=float(options['se']) if 'se' in options else None, stride=int(options['s']), - act_fn=act_fn, + act_layer=act_layer, noskip=noskip, ) elif block_type == 'cn': @@ -310,7 +327,7 @@ def _decode_block_str(block_str, depth_multiplier=1.0): kernel_size=int(options['k']), out_chs=int(options['c']), stride=int(options['s']), - act_fn=act_fn, + act_layer=act_layer, ) else: assert False, 'Unknown block type (%s)' % block_type @@ -356,7 +373,7 @@ def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='c return sa_scaled -def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil'): +def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1): arch_args = [] for stack_idx, block_strings in enumerate(arch_def): assert isinstance(block_strings, list) @@ -365,6 +382,8 @@ def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil'): for block_str in block_strings: assert isinstance(block_str, str) ba, rep = _decode_block_str(block_str) + if ba.get('num_experts', 0) > 0 and experts_multiplier > 1: + ba['num_experts'] *= experts_multiplier stack_args.append(ba) repeats.append(rep) arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc)) @@ -437,61 +456,67 @@ class _BlockBuilder: """ def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None, - pad_type='', act_fn=None, se_gate_fn=sigmoid, se_reduce_mid=False, - bn_args=_BN_ARGS_PT, drop_connect_rate=0., verbose=False): + output_stride=32, pad_type='', act_layer=None, se_gate_fn=sigmoid, se_reduce_mid=False, + norm_layer=nn.BatchNorm2d, norm_kwargs=_BN_ARGS_PT, drop_connect_rate=0., feature_location='', + verbose=False): self.channel_multiplier = channel_multiplier self.channel_divisor = 
channel_divisor self.channel_min = channel_min + self.output_stride = output_stride self.pad_type = pad_type - self.act_fn = act_fn + self.act_layer = act_layer self.se_gate_fn = se_gate_fn self.se_reduce_mid = se_reduce_mid - self.bn_args = bn_args + self.norm_layer = norm_layer + self.norm_kwargs = norm_kwargs self.drop_connect_rate = drop_connect_rate + self.feature_location = feature_location + assert feature_location in ('pre_pwl', 'post_exp', '') self.verbose = verbose - # updated during build + # state updated during build, consumed by model self.in_chs = None - self.block_idx = 0 - self.block_count = 0 + self.features = OrderedDict() def _round_channels(self, chs): return _round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min) - def _make_block(self, ba): + def _make_block(self, ba, block_idx, block_count): + drop_connect_rate = self.drop_connect_rate * block_idx / block_count bt = ba.pop('block_type') ba['in_chs'] = self.in_chs ba['out_chs'] = self._round_channels(ba['out_chs']) if 'fake_in_chs' in ba and ba['fake_in_chs']: # FIXME this is a hack to work around mismatch in origin impl input filters ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs']) - ba['bn_args'] = self.bn_args + ba['norm_layer'] = self.norm_layer + ba['norm_kwargs'] = self.norm_kwargs ba['pad_type'] = self.pad_type # block act fn overrides the model default - ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn - assert ba['act_fn'] is not None + ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer + assert ba['act_layer'] is not None if bt == 'ir': - ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count + ba['drop_connect_rate'] = drop_connect_rate ba['se_gate_fn'] = self.se_gate_fn ba['se_reduce_mid'] = self.se_reduce_mid if self.verbose: - logging.info(' InvertedResidual {}, Args: {}'.format(self.block_idx, str(ba))) + logging.info(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba))) block = InvertedResidual(**ba) elif bt == 'ds' or bt == 'dsa': - ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count + ba['drop_connect_rate'] = drop_connect_rate if self.verbose: - logging.info(' DepthwiseSeparable {}, Args: {}'.format(self.block_idx, str(ba))) + logging.info(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba))) block = DepthwiseSeparableConv(**ba) elif bt == 'er': - ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count + ba['drop_connect_rate'] = drop_connect_rate ba['se_gate_fn'] = self.se_gate_fn ba['se_reduce_mid'] = self.se_reduce_mid if self.verbose: - logging.info(' EdgeResidual {}, Args: {}'.format(self.block_idx, str(ba))) + logging.info(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba))) block = EdgeResidual(**ba) elif bt == 'cn': if self.verbose: - logging.info(' ConvBnAct {}, Args: {}'.format(self.block_idx, str(ba))) + logging.info(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba))) block = ConvBnAct(**ba) else: assert False, 'Uknkown block type (%s) while building model.' 
% bt @@ -499,46 +524,96 @@ class _BlockBuilder: return block - def _make_stack(self, stack_args): - blocks = [] - # each stack (stage) contains a list of block arguments - for i, ba in enumerate(stack_args): - if self.verbose: - logging.info(' Block: {}'.format(i)) - if i >= 1: - # only the first block in any stack can have a stride > 1 - ba['stride'] = 1 - block = self._make_block(ba) - blocks.append(block) - self.block_idx += 1 # incr global idx (across all stacks) - return nn.Sequential(*blocks) - - def __call__(self, in_chs, block_args): + def __call__(self, in_chs, model_block_args): """ Build the blocks Args: in_chs: Number of input-channels passed to first block - block_args: A list of lists, outer list defines stages, inner + model_block_args: A list of lists, outer list defines stages, inner list contains strings defining block configuration(s) Return: List of block stacks (each stack wrapped in nn.Sequential) """ if self.verbose: - logging.info('Building model trunk with %d stages...' % len(block_args)) + logging.info('Building model trunk with %d stages...' % len(model_block_args)) self.in_chs = in_chs - self.block_count = sum([len(x) for x in block_args]) - self.block_idx = 0 - blocks = [] + total_block_count = sum([len(x) for x in model_block_args]) + total_block_idx = 0 + current_stride = 2 + current_dilation = 1 + feature_idx = 0 + stages = [] # outer list of block_args defines the stacks ('stages' by some conventions) - for stack_idx, stack in enumerate(block_args): + for stage_idx, stage_block_args in enumerate(model_block_args): + last_stack = stage_idx == (len(model_block_args) - 1) if self.verbose: - logging.info('Stack: {}'.format(stack_idx)) - assert isinstance(stack, list) - stack = self._make_stack(stack) - blocks.append(stack) - return blocks - - -def _initialize_weight_goog(m): + logging.info('Stack: {}'.format(stage_idx)) + assert isinstance(stage_block_args, list) + + blocks = [] + # each stack (stage) contains a list of block arguments + for block_idx, block_args in enumerate(stage_block_args): + last_block = block_idx == (len(stage_block_args) - 1) + extract_features = '' # No features extracted + if self.verbose: + logging.info(' Block: {}'.format(block_idx)) + + # Sort out stride, dilation, and feature extraction details + assert block_args['stride'] in (1, 2) + if block_idx >= 1: + # only the first block in any stack can have a stride > 1 + block_args['stride'] = 1 + + do_extract = False + if self.feature_location == 'pre_pwl': + if last_block: + next_stage_idx = stage_idx + 1 + if next_stage_idx >= len(model_block_args): + do_extract = True + else: + do_extract = model_block_args[next_stage_idx][0]['stride'] > 1 + elif self.feature_location == 'post_exp': + if block_args['stride'] > 1 or (last_stack and last_block) : + do_extract = True + if do_extract: + extract_features = self.feature_location + + next_dilation = current_dilation + if block_args['stride'] > 1: + next_output_stride = current_stride * block_args['stride'] + if next_output_stride > self.output_stride: + next_dilation = current_dilation * block_args['stride'] + block_args['stride'] = 1 + if self.verbose: + logging.info(' Converting stride to dilation to maintain output_stride=={}'.format( + self.output_stride)) + else: + current_stride = next_output_stride + block_args['dilation'] = current_dilation + if next_dilation != current_dilation: + current_dilation = next_dilation + + # create the block + block = self._make_block(block_args, total_block_idx, total_block_count) + blocks.append(block) 
+ + # stash feature module name and channel info for model feature extraction + if extract_features: + feature_module = block.feature_module(extract_features) + if feature_module: + feature_module = 'blocks.{}.{}.'.format(stage_idx, block_idx) + feature_module + feature_channels = block.feature_channels(extract_features) + self.features[feature_idx] = dict( + name=feature_module, + num_chs=feature_channels + ) + feature_idx += 1 + + total_block_idx += 1 # incr global block idx (across all stacks) + stages.append(nn.Sequential(*blocks)) + return stages + + +def _init_weight_goog(m): # weight init as per Tensorflow Official impl # https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py if isinstance(m, nn.Conv2d): @@ -556,7 +631,7 @@ def _initialize_weight_goog(m): m.bias.data.zero_() -def _initialize_weight_default(m): +def _init_weight_default(m): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') elif isinstance(m, nn.BatchNorm2d): @@ -601,19 +676,19 @@ class ChannelShuffle(nn.Module): class SqueezeExcite(nn.Module): - def __init__(self, in_chs, reduce_chs=None, act_fn=F.relu, gate_fn=sigmoid): + def __init__(self, in_chs, reduce_chs=None, act_layer=nn.ReLU, gate_fn=sigmoid): super(SqueezeExcite, self).__init__() - self.act_fn = act_fn self.gate_fn = gate_fn reduced_chs = reduce_chs or in_chs self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True) + self.act1 = act_layer(inplace=True) self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True) def forward(self, x): # NOTE adaptiveavgpool can be used here, but seems to cause issues with NVIDIA AMP performance x_se = x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1) x_se = self.conv_reduce(x_se) - x_se = self.act_fn(x_se, inplace=True) + x_se = self.act1(x_se) x_se = self.conv_expand(x_se) x = x * self.gate_fn(x_se) return x @@ -621,17 +696,24 @@ class SqueezeExcite(nn.Module): class ConvBnAct(nn.Module): def __init__(self, in_chs, out_chs, kernel_size, - stride=1, pad_type='', act_fn=F.relu, bn_args=_BN_ARGS_PT): + stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, norm_kwargs=_BN_ARGS_PT,): super(ConvBnAct, self).__init__() assert stride in [1, 2] - self.act_fn = act_fn - self.conv = select_conv2d(in_chs, out_chs, kernel_size, stride=stride, padding=pad_type) - self.bn1 = nn.BatchNorm2d(out_chs, **bn_args) + self.conv = select_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type) + self.bn1 = norm_layer(out_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + def feature_module(self, location): + return 'act1' + + def feature_channels(self, location): + return self.conv.out_channels def forward(self, x): x = self.conv(x) x = self.bn1(x) - x = self.act_fn(x, inplace=True) + x = self.act1(x) return x @@ -639,29 +721,41 @@ class EdgeResidual(nn.Module): """ Residual block with expansion convolution followed by pointwise-linear w/ stride""" def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0, - stride=1, pad_type='', act_fn=F.relu, noskip=False, pw_kernel_size=1, + stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1, se_ratio=0., se_reduce_mid=False, se_gate_fn=sigmoid, - bn_args=_BN_ARGS_PT, drop_connect_rate=0.): + norm_layer=nn.BatchNorm2d, norm_kwargs=_BN_ARGS_PT, drop_connect_rate=0.): super(EdgeResidual, self).__init__() mid_chs = int(fake_in_chs * exp_ratio) if fake_in_chs > 0 else 
int(in_chs * exp_ratio) self.has_se = se_ratio is not None and se_ratio > 0. self.has_residual = (in_chs == out_chs and stride == 1) and not noskip - self.act_fn = act_fn self.drop_connect_rate = drop_connect_rate # Expansion convolution self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type) - self.bn1 = nn.BatchNorm2d(mid_chs, **bn_args) + self.bn1 = norm_layer(mid_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) # Squeeze-and-excitation if self.has_se: se_base_chs = mid_chs if se_reduce_mid else in_chs self.se = SqueezeExcite( - mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)), act_fn=act_fn, gate_fn=se_gate_fn) + mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)), act_layer=act_layer, gate_fn=se_gate_fn) # Point-wise linear projection - self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, stride=stride, padding=pad_type) - self.bn2 = nn.BatchNorm2d(out_chs, **bn_args) + self.conv_pwl = select_conv2d( + mid_chs, out_chs, pw_kernel_size, stride=stride, dilation=dilation, padding=pad_type) + self.bn2 = norm_layer(out_chs, **norm_kwargs) + + def feature_module(self, location): + if location == 'post_exp': + return 'act1' + return 'conv_pwl' + + def feature_channels(self, location): + if location == 'post_exp': + return self.conv_exp.out_channels + # location == 'pre_pw' + return self.conv_pwl.in_channels def forward(self, x): residual = x @@ -669,7 +763,7 @@ class EdgeResidual(nn.Module): # Expansion convolution x = self.conv_exp(x) x = self.bn1(x) - x = self.act_fn(x, inplace=True) + x = self.act1(x) # Squeeze-and-excitation if self.has_se: @@ -693,44 +787,50 @@ class DepthwiseSeparableConv(nn.Module): factor of 1.0. This is an alternative to having a IR with an optional first pw conv. """ def __init__(self, in_chs, out_chs, dw_kernel_size=3, - stride=1, pad_type='', act_fn=F.relu, noskip=False, - pw_kernel_size=1, pw_act=False, - se_ratio=0., se_gate_fn=sigmoid, - bn_args=_BN_ARGS_PT, drop_connect_rate=0.): + stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, + pw_kernel_size=1, pw_act=False, se_ratio=0., se_gate_fn=sigmoid, + norm_layer=nn.BatchNorm2d, norm_kwargs=_BN_ARGS_PT, drop_connect_rate=0.): super(DepthwiseSeparableConv, self).__init__() assert stride in [1, 2] self.has_se = se_ratio is not None and se_ratio > 0. 
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip self.has_pw_act = pw_act # activation after point-wise conv - self.act_fn = act_fn self.drop_connect_rate = drop_connect_rate self.conv_dw = select_conv2d( - in_chs, in_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True) - self.bn1 = nn.BatchNorm2d(in_chs, **bn_args) + in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, depthwise=True) + self.bn1 = norm_layer(in_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) # Squeeze-and-excitation if self.has_se: self.se = SqueezeExcite( - in_chs, reduce_chs=max(1, int(in_chs * se_ratio)), act_fn=act_fn, gate_fn=se_gate_fn) + in_chs, reduce_chs=max(1, int(in_chs * se_ratio)), act_layer=act_layer, gate_fn=se_gate_fn) self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type) - self.bn2 = nn.BatchNorm2d(out_chs, **bn_args) + self.bn2 = norm_layer(out_chs, **norm_kwargs) + self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity() + + def feature_module(self, location): + # no expansion in this block, pre pw only feature extraction point + return 'conv_pw' + + def feature_channels(self, location): + return self.conv_pw.in_channels def forward(self, x): residual = x x = self.conv_dw(x) x = self.bn1(x) - x = self.act_fn(x, inplace=True) + x = self.act1(x) if self.has_se: x = self.se(x) x = self.conv_pw(x) x = self.bn2(x) - if self.has_pw_act: - x = self.act_fn(x, inplace=True) + x = self.act2(x) if self.has_residual: if self.drop_connect_rate > 0.: @@ -740,67 +840,87 @@ class DepthwiseSeparableConv(nn.Module): class InvertedResidual(nn.Module): - """ Inverted residual block w/ optional SE""" + """ Inverted residual block w/ optional SE and CondConv routing""" def __init__(self, in_chs, out_chs, dw_kernel_size=3, - stride=1, pad_type='', act_fn=F.relu, noskip=False, + stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, se_ratio=0., se_reduce_mid=False, se_gate_fn=sigmoid, - shuffle_type=None, bn_args=_BN_ARGS_PT, drop_connect_rate=0.): + norm_layer=nn.BatchNorm2d, norm_kwargs=_BN_ARGS_PT, + num_experts=0, drop_connect_rate=0.): super(InvertedResidual, self).__init__() mid_chs = int(in_chs * exp_ratio) self.has_se = se_ratio is not None and se_ratio > 0. 
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip - self.act_fn = act_fn self.drop_connect_rate = drop_connect_rate - # Point-wise expansion - self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type) - self.bn1 = nn.BatchNorm2d(mid_chs, **bn_args) + self.num_experts = num_experts + extra_args = dict() + if num_experts > 0: + extra_args = dict(num_experts=self.num_experts) + self.routing_fn = nn.Linear(in_chs, self.num_experts) + self.routing_act = torch.sigmoid - self.shuffle_type = shuffle_type - if shuffle_type is not None and isinstance(exp_kernel_size, list): - self.shuffle = ChannelShuffle(len(exp_kernel_size)) + # Point-wise expansion + self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **extra_args) + self.bn1 = norm_layer(mid_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) # Depth-wise convolution self.conv_dw = select_conv2d( - mid_chs, mid_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True) - self.bn2 = nn.BatchNorm2d(mid_chs, **bn_args) + mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation, + padding=pad_type, depthwise=True, **extra_args) + self.bn2 = norm_layer(mid_chs, **norm_kwargs) + self.act2 = act_layer(inplace=True) # Squeeze-and-excitation if self.has_se: se_base_chs = mid_chs if se_reduce_mid else in_chs self.se = SqueezeExcite( - mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)), act_fn=act_fn, gate_fn=se_gate_fn) + mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)), act_layer=act_layer, gate_fn=se_gate_fn) # Point-wise linear projection - self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type) - self.bn3 = nn.BatchNorm2d(out_chs, **bn_args) + self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **extra_args) + self.bn3 = norm_layer(out_chs, **norm_kwargs) + + def feature_module(self, location): + if location == 'post_exp': + return 'act1' + return 'conv_pwl' + + def feature_channels(self, location): + if location == 'post_exp': + return self.conv_pw.out_channels + # location == 'pre_pw' + return self.conv_pwl.in_channels def forward(self, x): residual = x + conv_pw, conv_dw, conv_pwl = self.conv_pw, self.conv_dw, self.conv_pwl + if self.num_experts > 0: + pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1) + routing_weights = self.routing_act(self.routing_fn(pooled_inputs)) + conv_pw = partial(self.conv_pw, routing_weights=routing_weights) + conv_dw = partial(self.conv_dw, routing_weights=routing_weights) + conv_pwl = partial(self.conv_pwl, routing_weights=routing_weights) + # Point-wise expansion - x = self.conv_pw(x) + x = conv_pw(x) x = self.bn1(x) - x = self.act_fn(x, inplace=True) - - # FIXME haven't tried this yet - # for channel shuffle when using groups with pointwise convs as per FBNet variants - if self.shuffle_type == "mid": - x = self.shuffle(x) + x = self.act1(x) # Depth-wise convolution - x = self.conv_dw(x) + x = conv_dw(x) x = self.bn2(x) - x = self.act_fn(x, inplace=True) + x = self.act2(x) # Squeeze-and-excitation if self.has_se: x = self.se(x) # Point-wise linear projection - x = self.conv_pwl(x) + x = conv_pwl(x) x = self.bn3(x) if self.has_residual: @@ -808,12 +928,52 @@ class InvertedResidual(nn.Module): x = drop_connect(x, self.training, self.drop_connect_rate) x += residual - # NOTE maskrcnn_benchmark building blocks have an SE module defined here for some variants + return x + + +class _GenEfficientNet(nn.Module): + """ Generic EfficientNet Base + """ + + def 
__init__(self, block_args, in_chans=3, stem_size=32, + channel_multiplier=1.0, channel_divisor=8, channel_min=None, + output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0., + se_gate_fn=sigmoid, se_reduce_mid=False, norm_layer=nn.BatchNorm2d, norm_kwargs=_BN_ARGS_PT, + feature_location='pre_pwl'): + super(_GenEfficientNet, self).__init__() + self.drop_rate = drop_rate + self._in_chs = in_chans + + # Stem + stem_size = _round_channels(stem_size, channel_multiplier, channel_divisor, channel_min) + self.conv_stem = select_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = norm_layer(stem_size, **norm_kwargs) + self.act1 = act_layer(inplace=True) + self._in_chs = stem_size + + # Middle stages (IR/ER/DS Blocks) + builder = _BlockBuilder( + channel_multiplier, channel_divisor, channel_min, + output_stride, pad_type, act_layer, se_gate_fn, se_reduce_mid, + norm_layer, norm_kwargs, drop_connect_rate, feature_location=feature_location, verbose=_DEBUG) + self.blocks = nn.Sequential(*builder(self._in_chs, block_args)) + self.feature_info = builder.features + self._in_chs = builder.in_chs + + def as_sequential(self): + layers = [self.conv_stem, self.bn1, self.act1] + layers.extend(self.blocks) + return nn.Sequential(*layers) + def forward(self, x): + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + x = self.blocks(x) return x -class GenEfficientNet(nn.Module): +class GenEfficientNet(_GenEfficientNet): """ Generic EfficientNet An implementation of efficient network architectures, in many cases mobile optimized networks: @@ -828,46 +988,77 @@ class GenEfficientNet(nn.Module): * MixNet S, M, L """ - def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280, + def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32, channel_multiplier=1.0, channel_divisor=8, channel_min=None, - pad_type='', act_fn=F.relu, drop_rate=0., drop_connect_rate=0., - se_gate_fn=sigmoid, se_reduce_mid=False, bn_args=_BN_ARGS_PT, + pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0., + se_gate_fn=sigmoid, se_reduce_mid=False, + norm_layer=nn.BatchNorm2d, norm_kwargs=_BN_ARGS_PT, global_pool='avg', head_conv='default', weight_init='goog'): - super(GenEfficientNet, self).__init__() + self.num_classes = num_classes - self.drop_rate = drop_rate - self.act_fn = act_fn self.num_features = num_features - - stem_size = _round_channels(stem_size, channel_multiplier, channel_divisor, channel_min) - self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) - self.bn1 = nn.BatchNorm2d(stem_size, **bn_args) - in_chs = stem_size - - builder = _BlockBuilder( - channel_multiplier, channel_divisor, channel_min, - pad_type, act_fn, se_gate_fn, se_reduce_mid, - bn_args, drop_connect_rate, verbose=_DEBUG) - self.blocks = nn.Sequential(*builder(in_chs, block_args)) - in_chs = builder.in_chs - - if not head_conv or head_conv == 'none': - self.efficient_head = False - self.conv_head = None - assert in_chs == self.num_features - else: - self.efficient_head = head_conv == 'efficient' - self.conv_head = select_conv2d(in_chs, self.num_features, 1, padding=pad_type) - self.bn2 = None if self.efficient_head else nn.BatchNorm2d(self.num_features, **bn_args) - - self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + super(GenEfficientNet, self).__init__( # FIXME it would be nice if Python made this nicer + block_args, in_chans=in_chans, stem_size=stem_size, + 
pad_type=pad_type, act_layer=act_layer, drop_rate=drop_rate, drop_connect_rate=drop_connect_rate, + channel_multiplier=channel_multiplier, channel_divisor=channel_divisor, channel_min=channel_min, + se_gate_fn=se_gate_fn, se_reduce_mid=se_reduce_mid, norm_layer=norm_layer, norm_kwargs=norm_kwargs) + + # Head + Pooling + self.conv_head = None + self.global_pool = None + self.act2 = None + self.forward_head = None + self.head_conv = head_conv + if head_conv == 'efficient': + self._create_head_efficient(global_pool, pad_type, act_layer) + elif head_conv == 'default': + self._create_head_default(global_pool, pad_type, act_layer, norm_layer, norm_kwargs) + + # Classifier self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), self.num_classes) for m in self.modules(): if weight_init == 'goog': - _initialize_weight_goog(m) + _init_weight_goog(m) else: - _initialize_weight_default(m) + _init_weight_default(m) + + def _create_head_default(self, global_pool, pad_type, act_layer, norm_layer, norm_kwargs): + self.conv_head = select_conv2d(self._in_chs, self.num_features, 1, padding=pad_type) + self.bn2 = norm_layer(self.num_features, **norm_kwargs) + self.act2 = act_layer(inplace=True) + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + + def _create_head_efficient(self, global_pool, pad_type, act_layer): + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.conv_head = select_conv2d(self._in_chs, self.num_features, 1, padding=pad_type) + self.act2 = act_layer(inplace=True) + + def _forward_head_default(self, x): + x = self.conv_head(x) + x = self.bn2(x) + x = self.act2(x) + return x + + def _forward_head_efficient(self, x): + x = self.global_pool(x) + x = self.conv_head(x) + x = self.act2(x) + return x + + def as_sequential(self): + layers = [self.conv_stem, self.bn1, self.act1] + layers.extend(self.blocks) + if self.head_conv == 'efficient': + layers.extend([self.global_pool, self.bn2, self.act2]) + else: + layers.extend([self.conv_head, self.bn2, self.act2]) + if self.global_pool is not None: + layers.append(self.global_pool) + #append flatten layer + layers.append(self.classifier) + return nn.Sequential(*layers) + def get_classifier(self): return self.classifier @@ -882,38 +1073,121 @@ class GenEfficientNet(nn.Module): else: self.classifier = None - def forward_features(self, x, pool=True): - x = self.conv_stem(x) - x = self.bn1(x) - x = self.act_fn(x, inplace=True) - x = self.blocks(x) - if self.efficient_head: - # efficient head, currently only mobilenet-v3 performs pool before last 1x1 conv - x = self.global_pool(x) # always need to pool here regardless of flag - x = self.conv_head(x) - # no BN - x = self.act_fn(x, inplace=True) - if pool: - # expect flattened output if pool is true, otherwise keep dim - x = x.view(x.size(0), -1) - else: - if self.conv_head is not None: - x = self.conv_head(x) - x = self.bn2(x) - x = self.act_fn(x, inplace=True) - if pool: - x = self.global_pool(x) - x = x.view(x.size(0), -1) + def forward_features(self, x): + x = super(GenEfficientNet, self).forward(x) + if self.head_conv == 'efficient': + x = self._forward_head_efficient(x) + elif self.head_conv == 'default': + x = self._forward_head_default(x) return x def forward(self, x): x = self.forward_features(x) + if self.global_pool is not None and x.shape[-1] > 1 or x.shape[-2] > 1: + x = self.global_pool(x) + x = x.flatten(1) if self.drop_rate > 0.: x = F.dropout(x, p=self.drop_rate, training=self.training) return self.classifier(x) -def 
_gen_mnasnet_a1(channel_multiplier, num_classes=1000, **kwargs): +class GenEfficientNetFeatures(_GenEfficientNet): + """ Generic EfficientNet Feature Extractor + """ + + def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl', + in_chans=3, stem_size=32, channel_multiplier=1.0, channel_divisor=8, channel_min=None, + output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0., + se_gate_fn=sigmoid, se_reduce_mid=False, norm_layer=nn.BatchNorm2d, norm_kwargs=_BN_ARGS_PT, + weight_init='goog'): + + # validate and modify block arguments and out indices for feature extraction + num_stages = max(out_indices) + 1 # FIXME reduce num stages created if not needed + #assert len(block_args) >= num_stages - 1 + #block_args = block_args[:num_stages - 1] + + super(GenEfficientNetFeatures, self).__init__( # FIXME it would be nice if Python made this nicer + block_args, in_chans=in_chans, stem_size=stem_size, + output_stride=output_stride, pad_type=pad_type, act_layer=act_layer, + drop_rate=drop_rate, drop_connect_rate=drop_connect_rate, feature_location=feature_location, + channel_multiplier=channel_multiplier, channel_divisor=channel_divisor, channel_min=channel_min, + se_gate_fn=se_gate_fn, se_reduce_mid=se_reduce_mid, norm_layer=norm_layer, norm_kwargs=norm_kwargs) + + for m in self.modules(): + if weight_init == 'goog': + _init_weight_goog(m) + else: + _init_weight_default(m) + + if _DEBUG: + for k, v in self.feature_info.items(): + print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs'])) + hook_type = 'forward_pre' if feature_location == 'pre_pwl' else 'forward' + hooks = [dict(name=self.feature_info[idx]['name'], type=hook_type) for idx in out_indices] + self._feature_outputs = None + self._register_hooks(hooks) + + def _collect_output_hook(self, name, *args): + x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre + if isinstance(x, tuple): + x = x[0] # unwrap input tuple + self._feature_outputs[x.device][name] = x + + def _get_output(self, device): + output = tuple(self._feature_outputs[device].values())[::-1] + self._feature_outputs[device] = OrderedDict() + return output + + def _register_hooks(self, hooks): + # setup feature hooks + modules = {k: v for k, v in self.named_modules()} + for h in hooks: + hook_name = h['name'] + m = modules[hook_name] + hook_fn = partial(self._collect_output_hook, hook_name) + if h['type'] == 'forward_pre': + m.register_forward_pre_hook(hook_fn) + else: + m.register_forward_hook(hook_fn) + self._feature_outputs = defaultdict(OrderedDict) + + def feature_channels(self, idx=None): + if isinstance(idx, int): + return self.feature_info[idx]['num_chs'] + return [self.feature_info[i]['num_chs'] for i in self.out_indices] + + def forward(self, x): + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + self.blocks(x) + return self._get_output(x.device) + + +def _create_model(model_kwargs, default_cfg, pretrained=False): + if model_kwargs.pop('features_only', False): + load_strict = False + model_kwargs.pop('num_classes', 0) + model_kwargs.pop('num_features', 0) + model_kwargs.pop('head_conv', None) + model_class = GenEfficientNetFeatures + else: + load_strict = True + model_class = GenEfficientNet + + model = model_class(**model_kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained( + model, + default_cfg, + num_classes=model_kwargs.get('num_classes', 0), + in_chans=model_kwargs.get('in_chans', 3), + strict=load_strict) + return 
model + + +def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs): """Creates a mnasnet-a1 model. Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet @@ -938,18 +1212,18 @@ def _gen_mnasnet_a1(channel_multiplier, num_classes=1000, **kwargs): # stage 6, 7x7 in ['ir_r1_k3_s1_e6_c320'], ] - model = GenEfficientNet( - _decode_arch_def(arch_def), - num_classes=num_classes, + model_kwargs = dict( + block_args=_decode_arch_def(arch_def), stem_size=32, channel_multiplier=channel_multiplier, - bn_args=_resolve_bn_args(kwargs), + norm_kwargs=_resolve_bn_args(kwargs), **kwargs ) + model = _create_model(model_kwargs, default_cfgs[variant], pretrained) return model -def _gen_mnasnet_b1(channel_multiplier, num_classes=1000, **kwargs): +def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs): """Creates a mnasnet-b1 model. Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet @@ -974,18 +1248,18 @@ def _gen_mnasnet_b1(channel_multiplier, num_classes=1000, **kwargs): # stage 6, 7x7 in ['ir_r1_k3_s1_e6_c320_noskip'] ] - model = GenEfficientNet( - _decode_arch_def(arch_def), - num_classes=num_classes, + model_kwargs = dict( + block_args=_decode_arch_def(arch_def), stem_size=32, channel_multiplier=channel_multiplier, - bn_args=_resolve_bn_args(kwargs), + norm_kwargs=_resolve_bn_args(kwargs), **kwargs ) + model = _create_model(model_kwargs, default_cfgs[variant], pretrained) return model -def _gen_mnasnet_small(channel_multiplier, num_classes=1000, **kwargs): +def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwargs): """Creates a mnasnet-b1 model. Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet @@ -1003,18 +1277,18 @@ def _gen_mnasnet_small(channel_multiplier, num_classes=1000, **kwargs): ['ir_r3_k5_s2_e6_c88_se0.25'], ['ir_r1_k3_s1_e6_c144'] ] - model = GenEfficientNet( - _decode_arch_def(arch_def), - num_classes=num_classes, + model_kwargs = dict( + block_args=_decode_arch_def(arch_def), stem_size=8, channel_multiplier=channel_multiplier, - bn_args=_resolve_bn_args(kwargs), + norm_kwargs=_resolve_bn_args(kwargs), **kwargs ) + model = _create_model(model_kwargs, default_cfgs[variant], pretrained) return model -def _gen_mobilenet_v1(channel_multiplier, num_classes=1000, **kwargs): +def _gen_mobilenet_v1(variant, channel_multiplier=1.0, pretrained=False, **kwargs): """ Generate MobileNet-V1 network Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py Paper: https://arxiv.org/abs/1801.04381 @@ -1026,21 +1300,21 @@ def _gen_mobilenet_v1(channel_multiplier, num_classes=1000, **kwargs): ['dsa_r6_k3_s2_c512'], ['dsa_r2_k3_s2_c1024'], ] - model = GenEfficientNet( - _decode_arch_def(arch_def), - num_classes=num_classes, + model_kwargs = dict( + block_args=_decode_arch_def(arch_def), stem_size=32, num_features=1024, channel_multiplier=channel_multiplier, - bn_args=_resolve_bn_args(kwargs), - act_fn=F.relu6, + norm_kwargs=_resolve_bn_args(kwargs), + act_layer=nn.ReLU6, head_conv='none', **kwargs - ) + ) + model = _create_model(model_kwargs, default_cfgs[variant], pretrained) return model -def _gen_mobilenet_v2(channel_multiplier, num_classes=1000, **kwargs): +def _gen_mobilenet_v2(variant, channel_multiplier=1.0, pretrained=False, **kwargs): """ Generate MobileNet-V2 network Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py Paper: 
https://arxiv.org/abs/1801.04381 @@ -1054,19 +1328,19 @@ def _gen_mobilenet_v2(channel_multiplier, num_classes=1000, **kwargs): ['ir_r3_k3_s2_e6_c160'], ['ir_r1_k3_s1_e6_c320'], ] - model = GenEfficientNet( - _decode_arch_def(arch_def), - num_classes=num_classes, + model_kwargs = dict( + block_args=_decode_arch_def(arch_def), stem_size=32, channel_multiplier=channel_multiplier, - bn_args=_resolve_bn_args(kwargs), - act_fn=F.relu6, + norm_kwargs=_resolve_bn_args(kwargs), + act_layer=nn.ReLU6, **kwargs ) + model = _create_model(model_kwargs, default_cfgs[variant], pretrained) return model -def _gen_mobilenet_v3(channel_multiplier, num_classes=1000, **kwargs): +def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs): """Creates a MobileNet-V3 model. Ref impl: ? @@ -1091,22 +1365,22 @@ def _gen_mobilenet_v3(channel_multiplier, num_classes=1000, **kwargs): # stage 6, 7x7 in ['cn_r1_k1_s1_c960'], # hard-swish ] - model = GenEfficientNet( - _decode_arch_def(arch_def), - num_classes=num_classes, + model_kwargs = dict( + block_args=_decode_arch_def(arch_def), stem_size=16, channel_multiplier=channel_multiplier, - bn_args=_resolve_bn_args(kwargs), - act_fn=hard_swish, + norm_kwargs=_resolve_bn_args(kwargs), + act_layer=HardSwish, se_gate_fn=hard_sigmoid, se_reduce_mid=True, head_conv='efficient', - **kwargs + **kwargs, ) + model = _create_model(model_kwargs, default_cfgs[variant], pretrained) return model -def _gen_chamnet_v1(channel_multiplier, num_classes=1000, **kwargs): +def _gen_chamnet_v1(variant, channel_multiplier=1.0, pretrained=False, **kwargs): """ Generate Chameleon Network (ChamNet) Paper: https://arxiv.org/abs/1812.08934 @@ -1123,19 +1397,19 @@ def _gen_chamnet_v1(channel_multiplier, num_classes=1000, **kwargs): ['ir_r4_k3_s2_e7_c152'], ['ir_r1_k3_s1_e10_c104'], ] - model = GenEfficientNet( - _decode_arch_def(arch_def), - num_classes=num_classes, + model_kwargs = dict( + block_args=_decode_arch_def(arch_def), stem_size=32, num_features=1280, # no idea what this is? try mobile/mnasnet default? channel_multiplier=channel_multiplier, - bn_args=_resolve_bn_args(kwargs), + norm_kwargs=_resolve_bn_args(kwargs), **kwargs ) + model = _create_model(model_kwargs, default_cfgs[variant], pretrained) return model -def _gen_chamnet_v2(channel_multiplier, num_classes=1000, **kwargs): +def _gen_chamnet_v2(variant, channel_multiplier=1.0, pretrained=False, **kwargs): """ Generate Chameleon Network (ChamNet) Paper: https://arxiv.org/abs/1812.08934 @@ -1152,19 +1426,19 @@ def _gen_chamnet_v2(channel_multiplier, num_classes=1000, **kwargs): ['ir_r6_k3_s2_e2_c152'], ['ir_r1_k3_s1_e6_c112'], ] - model = GenEfficientNet( - _decode_arch_def(arch_def), - num_classes=num_classes, + model_kwargs = dict( + block_args=_decode_arch_def(arch_def), stem_size=32, num_features=1280, # no idea what this is? try mobile/mnasnet default? 
channel_multiplier=channel_multiplier, - bn_args=_resolve_bn_args(kwargs), + norm_kwargs=_resolve_bn_args(kwargs), **kwargs ) + model = _create_model(model_kwargs, default_cfgs[variant], pretrained) return model -def _gen_fbnetc(channel_multiplier, num_classes=1000, **kwargs): +def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs): """ FBNet-C Paper: https://arxiv.org/abs/1812.03443 @@ -1182,19 +1456,19 @@ def _gen_fbnetc(channel_multiplier, num_classes=1000, **kwargs): ['ir_r4_k5_s2_e6_c184'], ['ir_r1_k3_s1_e6_c352'], ] - model = GenEfficientNet( - _decode_arch_def(arch_def), - num_classes=num_classes, + model_kwargs = dict( + block_args=_decode_arch_def(arch_def), stem_size=16, num_features=1984, # paper suggests this, but is not 100% clear channel_multiplier=channel_multiplier, - bn_args=_resolve_bn_args(kwargs), + norm_kwargs=_resolve_bn_args(kwargs), **kwargs ) + model = _create_model(model_kwargs, default_cfgs[variant], pretrained) return model -def _gen_spnasnet(channel_multiplier, num_classes=1000, **kwargs): +def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs): """Creates the Single-Path NAS model from search targeted for Pixel1 phone. Paper: https://arxiv.org/abs/1904.02877 @@ -1218,18 +1492,18 @@ def _gen_spnasnet(channel_multiplier, num_classes=1000, **kwargs): # stage 6, 7x7 in ['ir_r1_k3_s1_e6_c320_noskip'] ] - model = GenEfficientNet( - _decode_arch_def(arch_def), - num_classes=num_classes, + model_kwargs = dict( + block_args=_decode_arch_def(arch_def), stem_size=32, channel_multiplier=channel_multiplier, - bn_args=_resolve_bn_args(kwargs), + norm_kwargs=_resolve_bn_args(kwargs), **kwargs ) + model = _create_model(model_kwargs, default_cfgs[variant], pretrained) return model -def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000, **kwargs): +def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): """Creates an EfficientNet model. 
Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py @@ -1260,21 +1534,20 @@ def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes= ['ir_r4_k5_s2_e6_c192_se0.25'], ['ir_r1_k3_s1_e6_c320_se0.25'], ] - num_features = _round_channels(1280, channel_multiplier, 8, None) - model = GenEfficientNet( - _decode_arch_def(arch_def, depth_multiplier), - num_classes=num_classes, + model_kwargs = dict( + block_args=_decode_arch_def(arch_def, depth_multiplier), + num_features=_round_channels(1280, channel_multiplier, 8, None), stem_size=32, channel_multiplier=channel_multiplier, - num_features=num_features, - bn_args=_resolve_bn_args(kwargs), - act_fn=swish, - **kwargs + norm_kwargs=_resolve_bn_args(kwargs), + act_layer=Swish, + **kwargs, ) + model = _create_model(model_kwargs, default_cfgs[variant], pretrained) return model -def _gen_efficientnet_edge(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000, **kwargs): +def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): arch_def = [ # NOTE `fc` is present to override a mismatch between stem channels and in chs not # present in other models @@ -1285,21 +1558,46 @@ def _gen_efficientnet_edge(channel_multiplier=1.0, depth_multiplier=1.0, num_cla ['ir_r4_k5_s1_e8_c144'], ['ir_r2_k5_s2_e8_c192'], ] - num_features = _round_channels(1280, channel_multiplier, 8, None) - model = GenEfficientNet( - _decode_arch_def(arch_def, depth_multiplier), - num_classes=num_classes, + model_kwargs = dict( + block_args=_decode_arch_def(arch_def, depth_multiplier), + num_features=_round_channels(1280, channel_multiplier, 8, None), stem_size=32, channel_multiplier=channel_multiplier, - num_features=num_features, - bn_args=_resolve_bn_args(kwargs), - act_fn=F.relu, - **kwargs + norm_kwargs=_resolve_bn_args(kwargs), + act_layer=nn.ReLU, + **kwargs, + ) + model = _create_model(model_kwargs, default_cfgs[variant], pretrained) + return model + + +def _gen_efficientnet_condconv( + variant, channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=1, pretrained=False, **kwargs): + + """Creates an efficientnet-condconv model.""" + arch_def = [ + ['ds_r1_k3_s1_e1_c16_se0.25'], + ['ir_r2_k3_s2_e6_c24_se0.25'], + ['ir_r2_k5_s2_e6_c40_se0.25'], + ['ir_r3_k3_s2_e6_c80_se0.25'], + ['ir_r3_k5_s1_e6_c112_se0.25_cc4'], + ['ir_r4_k5_s2_e6_c192_se0.25_cc4'], + ['ir_r1_k3_s1_e6_c320_se0.25_cc4'], + ] + model_kwargs = dict( + block_args=_decode_arch_def(arch_def, depth_multiplier, experts_multiplier=experts_multiplier), + num_features=_round_channels(1280, channel_multiplier, 8, None), + stem_size=32, + channel_multiplier=channel_multiplier, + norm_kwargs=_resolve_bn_args(kwargs), + act_layer=Swish, + **kwargs, ) + model = _create_model(model_kwargs, default_cfgs[variant], pretrained) return model -def _gen_mixnet_s(channel_multiplier=1.0, num_classes=1000, **kwargs): +def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs): """Creates a MixNet Small model. 
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet @@ -1320,20 +1618,19 @@ def _gen_mixnet_s(channel_multiplier=1.0, num_classes=1000, **kwargs): ['ir_r1_k3.5.7.9.11_s2_e6_c200_se0.5_nsw', 'ir_r2_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish # 7x7 ] - model = GenEfficientNet( - _decode_arch_def(arch_def), - num_classes=num_classes, - stem_size=16, + model_kwargs = dict( + block_args=_decode_arch_def(arch_def), num_features=1536, + stem_size=16, channel_multiplier=channel_multiplier, - bn_args=_resolve_bn_args(kwargs), - act_fn=F.relu, + norm_kwargs=_resolve_bn_args(kwargs), **kwargs ) + model = _create_model(model_kwargs, default_cfgs[variant], pretrained) return model -def _gen_mixnet_m(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000, **kwargs): +def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): """Creates a MixNet Medium-Large model. Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet @@ -1354,672 +1651,524 @@ def _gen_mixnet_m(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000 ['ir_r1_k3.5.7.9_s2_e6_c200_se0.5_nsw', 'ir_r3_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish # 7x7 ] - model = GenEfficientNet( - _decode_arch_def(arch_def, depth_multiplier=depth_multiplier, depth_trunc='round'), - num_classes=num_classes, - stem_size=24, + model_kwargs = dict( + block_args=_decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'), num_features=1536, + stem_size=24, channel_multiplier=channel_multiplier, - bn_args=_resolve_bn_args(kwargs), - act_fn=F.relu, + norm_kwargs=_resolve_bn_args(kwargs), **kwargs ) + model = _create_model(model_kwargs, default_cfgs[variant], pretrained) return model @register_model -def mnasnet_050(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def mnasnet_050(pretrained=False, **kwargs): """ MNASNet B1, depth multiplier of 0.5. """ - default_cfg = default_cfgs['mnasnet_050'] - model = _gen_mnasnet_b1(0.5, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + model = _gen_mnasnet_b1('mnasnet_050', 0.5, pretrained=pretrained, **kwargs) return model @register_model -def mnasnet_075(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def mnasnet_075(pretrained=False, **kwargs): """ MNASNet B1, depth multiplier of 0.75. """ - default_cfg = default_cfgs['mnasnet_075'] - model = _gen_mnasnet_b1(0.75, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + model = _gen_mnasnet_b1('mnasnet_075', 0.75, pretrained=pretrained, **kwargs) return model @register_model -def mnasnet_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def mnasnet_100(pretrained=False, **kwargs): """ MNASNet B1, depth multiplier of 1.0. """ - default_cfg = default_cfgs['mnasnet_100'] - model = _gen_mnasnet_b1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + model = _gen_mnasnet_b1('mnasnet_100', 1.0, pretrained=pretrained, **kwargs) return model @register_model -def mnasnet_b1(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def mnasnet_b1(pretrained=False, **kwargs): """ MNASNet B1, depth multiplier of 1.0. 
""" - return mnasnet_100(pretrained, num_classes, in_chans, **kwargs) + return mnasnet_100(pretrained, **kwargs) @register_model -def mnasnet_140(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def mnasnet_140(pretrained=False, **kwargs): """ MNASNet B1, depth multiplier of 1.4 """ - default_cfg = default_cfgs['mnasnet_140'] - model = _gen_mnasnet_b1(1.4, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + model = _gen_mnasnet_b1('mnasnet_140', 1.4, pretrained=pretrained, **kwargs) return model @register_model -def semnasnet_050(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def semnasnet_050(pretrained=False, **kwargs): """ MNASNet A1 (w/ SE), depth multiplier of 0.5 """ - default_cfg = default_cfgs['semnasnet_050'] - model = _gen_mnasnet_a1(0.5, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + model = _gen_mnasnet_a1('semnasnet_050', 0.5, pretrained=pretrained, **kwargs) return model @register_model -def semnasnet_075(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def semnasnet_075(pretrained=False, **kwargs): """ MNASNet A1 (w/ SE), depth multiplier of 0.75. """ - default_cfg = default_cfgs['semnasnet_075'] - model = _gen_mnasnet_a1(0.75, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + model = _gen_mnasnet_a1('semnasnet_075', 0.75, pretrained=pretrained, **kwargs) return model @register_model -def semnasnet_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def semnasnet_100(pretrained=False, **kwargs): """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """ - default_cfg = default_cfgs['semnasnet_100'] - model = _gen_mnasnet_a1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + model = _gen_mnasnet_a1('semnasnet_100', 1.0, pretrained=pretrained, **kwargs) return model @register_model -def mnasnet_a1(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def mnasnet_a1(pretrained=False, **kwargs): """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """ - return semnasnet_100(pretrained, num_classes, in_chans, **kwargs) + return semnasnet_100(pretrained, **kwargs) @register_model -def semnasnet_140(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def semnasnet_140(pretrained=False, **kwargs): """ MNASNet A1 (w/ SE), depth multiplier of 1.4. """ - default_cfg = default_cfgs['semnasnet_140'] - model = _gen_mnasnet_a1(1.4, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + model = _gen_mnasnet_a1('semnasnet_140', 1.4, pretrained=pretrained, **kwargs) return model @register_model -def mnasnet_small(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def mnasnet_small(pretrained=False, **kwargs): """ MNASNet Small, depth multiplier of 1.0. 
""" - default_cfg = default_cfgs['mnasnet_small'] - model = _gen_mnasnet_small(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + model = _gen_mnasnet_small('mnasnet_small', 1.0, pretrained=pretrained, **kwargs) return model @register_model -def mobilenetv1_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def mobilenetv1_100(pretrained=False, **kwargs): """ MobileNet V1 """ - default_cfg = default_cfgs['mobilenetv1_100'] - model = _gen_mobilenet_v1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + model = _gen_mobilenet_v1('mobilenetv1_100', 1.0, pretrained=pretrained, **kwargs) return model @register_model -def mobilenetv2_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def mobilenetv2_100(pretrained=False, **kwargs): """ MobileNet V2 """ - default_cfg = default_cfgs['mobilenetv2_100'] - model = _gen_mobilenet_v2(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + model = _gen_mobilenet_v2('mobilenetv2_100', 1.0, pretrained=pretrained, **kwargs) return model @register_model -def mobilenetv3_050(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def mobilenetv3_050(pretrained=False, **kwargs): """ MobileNet V3 """ - default_cfg = default_cfgs['mobilenetv3_050'] - model = _gen_mobilenet_v3(0.5, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + model = _gen_mobilenet_v3('mobilenetv3_050', 0.5, pretrained=pretrained, **kwargs) return model @register_model -def mobilenetv3_075(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def mobilenetv3_075(pretrained=False, **kwargs): """ MobileNet V3 """ - default_cfg = default_cfgs['mobilenetv3_075'] - model = _gen_mobilenet_v3(0.75, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + model = _gen_mobilenet_v3('mobilenetv3_075', 0.75, pretrained=pretrained, **kwargs) return model @register_model -def mobilenetv3_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def mobilenetv3_100(pretrained=False, **kwargs): """ MobileNet V3 """ - default_cfg = default_cfgs['mobilenetv3_100'] if pretrained: # pretrained model trained with non-default BN epsilon kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT - model = _gen_mobilenet_v3(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + model = _gen_mobilenet_v3('mobilenetv3_100', 1.0, pretrained=pretrained, **kwargs) return model @register_model -def fbnetc_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def fbnetc_100(pretrained=False, **kwargs): """ FBNet-C """ - default_cfg = default_cfgs['fbnetc_100'] if pretrained: # pretrained model trained with non-default BN epsilon kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT - model = _gen_fbnetc(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + model = _gen_fbnetc('fbnetc_100', 1.0, 
pretrained=pretrained, **kwargs) return model @register_model -def chamnetv1_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def chamnetv1_100(pretrained=False, **kwargs): """ ChamNet """ - default_cfg = default_cfgs['chamnetv1_100'] - model = _gen_chamnet_v1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + model = _gen_chamnet_v1('chamnetv1_100', 1.0, pretrained=pretrained, **kwargs) return model @register_model -def chamnetv2_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def chamnetv2_100(pretrained=False, **kwargs): """ ChamNet """ - default_cfg = default_cfgs['chamnetv2_100'] - model = _gen_chamnet_v2(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + model = _gen_chamnet_v2('chamnetv2_100', 1.0, pretrained=pretrained, **kwargs) return model @register_model -def spnasnet_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def spnasnet_100(pretrained=False, **kwargs): """ Single-Path NAS Pixel1""" - default_cfg = default_cfgs['spnasnet_100'] - model = _gen_spnasnet(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + model = _gen_spnasnet('spnasnet_100', 1.0, pretrained=pretrained, **kwargs) return model @register_model -def efficientnet_b0(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def efficientnet_b0(pretrained=False, **kwargs): """ EfficientNet-B0 """ - default_cfg = default_cfgs['efficientnet_b0'] # NOTE for train, drop_rate should be 0.2 #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg model = _gen_efficientnet( - channel_multiplier=1.0, depth_multiplier=1.0, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + 'efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) return model @register_model -def efficientnet_b1(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def efficientnet_b1(pretrained=False, **kwargs): """ EfficientNet-B1 """ - default_cfg = default_cfgs['efficientnet_b1'] # NOTE for train, drop_rate should be 0.2 #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg model = _gen_efficientnet( - channel_multiplier=1.0, depth_multiplier=1.1, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + 'efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) return model @register_model -def efficientnet_b2(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def efficientnet_b2(pretrained=False, **kwargs): """ EfficientNet-B2 """ - default_cfg = default_cfgs['efficientnet_b2'] # NOTE for train, drop_rate should be 0.3 #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg model = _gen_efficientnet( - channel_multiplier=1.1, depth_multiplier=1.2, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + 'efficientnet_b2', channel_multiplier=1.1, 
depth_multiplier=1.2, pretrained=pretrained, **kwargs) return model @register_model -def efficientnet_b3(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def efficientnet_b3(pretrained=False, **kwargs): """ EfficientNet-B3 """ - default_cfg = default_cfgs['efficientnet_b3'] # NOTE for train, drop_rate should be 0.3 #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg model = _gen_efficientnet( - channel_multiplier=1.2, depth_multiplier=1.4, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + 'efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) return model @register_model -def efficientnet_b4(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def efficientnet_b4(pretrained=False, **kwargs): """ EfficientNet-B4 """ - default_cfg = default_cfgs['efficientnet_b4'] # NOTE for train, drop_rate should be 0.4 #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg model = _gen_efficientnet( - channel_multiplier=1.4, depth_multiplier=1.8, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + 'efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) return model @register_model -def efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def efficientnet_b5(pretrained=False, **kwargs): """ EfficientNet-B5 """ # NOTE for train, drop_rate should be 0.4 #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg - default_cfg = default_cfgs['efficientnet_b5'] model = _gen_efficientnet( - channel_multiplier=1.6, depth_multiplier=2.2, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + 'efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) return model @register_model -def efficientnet_b6(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def efficientnet_b6(pretrained=False, **kwargs): """ EfficientNet-B6 """ # NOTE for train, drop_rate should be 0.5 #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg - default_cfg = default_cfgs['efficientnet_b6'] model = _gen_efficientnet( - channel_multiplier=1.8, depth_multiplier=2.6, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + 'efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) return model @register_model -def efficientnet_b7(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def efficientnet_b7(pretrained=False, **kwargs): """ EfficientNet-B7 """ # NOTE for train, drop_rate should be 0.5 #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg - default_cfg = default_cfgs['efficientnet_b7'] model = _gen_efficientnet( - channel_multiplier=2.0, depth_multiplier=3.1, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + 'efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) return model @register_model -def 
efficientnet_es(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def efficientnet_es(pretrained=False, **kwargs): """ EfficientNet-Edge Small. """ - default_cfg = default_cfgs['efficientnet_es'] model = _gen_efficientnet_edge( - channel_multiplier=1.0, depth_multiplier=1.0, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + 'efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) return model @register_model -def efficientnet_em(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def efficientnet_em(pretrained=False, **kwargs): """ EfficientNet-Edge-Medium. """ - default_cfg = default_cfgs['efficientnet_em'] model = _gen_efficientnet_edge( - channel_multiplier=1.0, depth_multiplier=1.1, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + 'efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) return model @register_model -def efficientnet_el(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def efficientnet_el(pretrained=False, **kwargs): """ EfficientNet-Edge-Large. """ - default_cfg = default_cfgs['efficientnet_el'] model = _gen_efficientnet_edge( - channel_multiplier=1.2, depth_multiplier=1.4, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + 'efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_cc_b0_4e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 4 Experts """ + # NOTE for train, drop_rate should be 0.2 + #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg + model = _gen_efficientnet_condconv( + 'efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) return model @register_model -def tf_efficientnet_b0(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def efficientnet_cc_b0_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 8 Experts """ + # NOTE for train, drop_rate should be 0.2 + #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg + model = _gen_efficientnet_condconv( + 'efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + +@register_model +def efficientnet_cc_b1_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B1 w/ 8 Experts """ + # NOTE for train, drop_rate should be 0.2 + #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg + model = _gen_efficientnet_condconv( + 'efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b0(pretrained=False, **kwargs): """ EfficientNet-B0.
Tensorflow compatible variant """ - default_cfg = default_cfgs['tf_efficientnet_b0'] kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( - channel_multiplier=1.0, depth_multiplier=1.0, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + 'tf_efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) return model @register_model -def tf_efficientnet_b1(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def tf_efficientnet_b1(pretrained=False, **kwargs): """ EfficientNet-B1. Tensorflow compatible variant """ - default_cfg = default_cfgs['tf_efficientnet_b1'] kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( - channel_multiplier=1.0, depth_multiplier=1.1, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + 'tf_efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) return model @register_model -def tf_efficientnet_b2(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def tf_efficientnet_b2(pretrained=False, **kwargs): """ EfficientNet-B2. Tensorflow compatible variant """ - default_cfg = default_cfgs['tf_efficientnet_b2'] kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( - channel_multiplier=1.1, depth_multiplier=1.2, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + 'tf_efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) return model @register_model def tf_efficientnet_b3(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ EfficientNet-B3. Tensorflow compatible variant """ - default_cfg = default_cfgs['tf_efficientnet_b3'] kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( - channel_multiplier=1.2, depth_multiplier=1.4, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + 'tf_efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) return model @register_model -def tf_efficientnet_b4(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def tf_efficientnet_b4(pretrained=False, **kwargs): """ EfficientNet-B4. Tensorflow compatible variant """ - default_cfg = default_cfgs['tf_efficientnet_b4'] kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( - channel_multiplier=1.4, depth_multiplier=1.8, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + 'tf_efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) return model @register_model -def tf_efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def tf_efficientnet_b5(pretrained=False, **kwargs): """ EfficientNet-B5. 
Tensorflow compatible variant """ - default_cfg = default_cfgs['tf_efficientnet_b5'] kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( - channel_multiplier=1.6, depth_multiplier=2.2, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + 'tf_efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) return model @register_model -def tf_efficientnet_b6(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def tf_efficientnet_b6(pretrained=False, **kwargs): """ EfficientNet-B6. Tensorflow compatible variant """ # NOTE for train, drop_rate should be 0.5 - default_cfg = default_cfgs['tf_efficientnet_b6'] kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( - channel_multiplier=1.8, depth_multiplier=2.6, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + 'tf_efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) return model @register_model -def tf_efficientnet_b7(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def tf_efficientnet_b7(pretrained=False, **kwargs): """ EfficientNet-B7. Tensorflow compatible variant """ # NOTE for train, drop_rate should be 0.5 - default_cfg = default_cfgs['tf_efficientnet_b7'] kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( - channel_multiplier=2.0, depth_multiplier=3.1, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + 'tf_efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) return model @register_model -def tf_efficientnet_es(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def tf_efficientnet_es(pretrained=False, **kwargs): """ EfficientNet-Edge Small. Tensorflow compatible variant """ - default_cfg = default_cfgs['tf_efficientnet_es'] kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet_edge( - channel_multiplier=1.0, depth_multiplier=1.0, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + 'tf_efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) return model @register_model -def tf_efficientnet_em(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def tf_efficientnet_em(pretrained=False, **kwargs): """ EfficientNet-Edge-Medium. Tensorflow compatible variant """ - default_cfg = default_cfgs['tf_efficientnet_em'] kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet_edge( - channel_multiplier=1.0, depth_multiplier=1.1, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + 'tf_efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) return model @register_model -def tf_efficientnet_el(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def tf_efficientnet_el(pretrained=False, **kwargs): """ EfficientNet-Edge-Large. 
Tensorflow compatible variant """ - default_cfg = default_cfgs['tf_efficientnet_el'] kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet_edge( - channel_multiplier=1.2, depth_multiplier=1.4, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + 'tf_efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_cc_b0_4e(pretrained=False, **kwargs): + """ EfficientNet-B0 """ + # NOTE for train, drop_rate should be 0.2 + #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg + kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_condconv( + 'tf_efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs): + """ EfficientNet-B0 """ + # NOTE for train, drop_rate should be 0.2 + #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg + kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_condconv( + 'tf_efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + +@register_model +def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs): + """ EfficientNet-B0 """ + # NOTE for train, drop_rate should be 0.2 + #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg + kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_condconv( + 'tf_efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2, + pretrained=pretrained, **kwargs) return model @register_model -def mixnet_s(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def mixnet_s(pretrained=False, **kwargs): """Creates a MixNet Small model. """ - default_cfg = default_cfgs['mixnet_s'] model = _gen_mixnet_s( - channel_multiplier=1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + 'mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs) return model @register_model -def mixnet_m(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def mixnet_m(pretrained=False, **kwargs): """Creates a MixNet Medium model. """ - default_cfg = default_cfgs['mixnet_m'] model = _gen_mixnet_m( - channel_multiplier=1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + 'mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs) return model @register_model -def mixnet_l(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def mixnet_l(pretrained=False, **kwargs): """Creates a MixNet Large model. 
""" - default_cfg = default_cfgs['mixnet_l'] model = _gen_mixnet_m( - channel_multiplier=1.3, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + 'mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs) return model @register_model -def mixnet_xl(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def mixnet_xl(pretrained=False, **kwargs): """Creates a MixNet Extra-Large model. Not a paper spec, experimental def by RW w/ depth scaling. """ - default_cfg = default_cfgs['mixnet_xl'] - #kwargs['drop_connect_rate'] = 0.2 model = _gen_mixnet_m( - channel_multiplier=1.6, depth_multiplier=1.2, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + 'mixnet_xl', channel_multiplier=1.6, depth_multiplier=1.2, pretrained=pretrained, **kwargs) return model @register_model -def mixnet_xxl(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def mixnet_xxl(pretrained=False, **kwargs): """Creates a MixNet Double Extra Large model. Not a paper spec, experimental def by RW w/ depth scaling. """ - default_cfg = default_cfgs['mixnet_xxl'] # kwargs['drop_connect_rate'] = 0.2 model = _gen_mixnet_m( - channel_multiplier=2.4, depth_multiplier=1.3, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + 'mixnet_xxl', channel_multiplier=2.4, depth_multiplier=1.3, pretrained=pretrained, **kwargs) return model @register_model -def tf_mixnet_s(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def tf_mixnet_s(pretrained=False, **kwargs): """Creates a MixNet Small model. Tensorflow compatible variant """ - default_cfg = default_cfgs['tf_mixnet_s'] kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_mixnet_s( - channel_multiplier=1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + 'tf_mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs) return model @register_model -def tf_mixnet_m(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def tf_mixnet_m(pretrained=False, **kwargs): """Creates a MixNet Medium model. Tensorflow compatible variant """ - default_cfg = default_cfgs['tf_mixnet_m'] kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_mixnet_m( - channel_multiplier=1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + 'tf_mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs) return model @register_model -def tf_mixnet_l(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def tf_mixnet_l(pretrained=False, **kwargs): """Creates a MixNet Large model. 
Tensorflow compatible variant """ - default_cfg = default_cfgs['tf_mixnet_l'] kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_mixnet_m( - channel_multiplier=1.3, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) + 'tf_mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs) return model diff --git a/timm/models/helpers.py b/timm/models/helpers.py index 9ac728da..7460f4a2 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -57,7 +57,7 @@ def resume_checkpoint(model, checkpoint_path): raise FileNotFoundError() -def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None): +def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True): if cfg is None: cfg = getattr(model, 'default_cfg') if cfg is None or 'url' not in cfg or not cfg['url']: @@ -74,7 +74,6 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non elif in_chans != 3: assert False, "Invalid in_chans for pretrained weights" - strict = True classifier_name = cfg['classifier'] if num_classes == 1000 and cfg['num_classes'] == 1001: # special case for imagenet trained models with extra background class in pretrained weights
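A note on the wiring above: every generator now assembles a plain model_kwargs dict and hands it to _create_model(model_kwargs, default_cfgs[variant], pretrained), which is defined earlier in gen_efficientnet.py and not visible in this hunk. The sketch below is only a guess at its shape, inferred from how it is called here and from the new strict argument on load_pretrained; the features_only branch and every name inside it are assumptions, not the actual helper.

    def _create_model_sketch(model_kwargs, default_cfg, pretrained=False):
        # hypothetical reconstruction for readers of this diff only
        load_strict = True
        if model_kwargs.pop('features_only', False):
            # a model built without its classifier head cannot load a full
            # checkpoint strictly, hence load_pretrained(..., strict=...)
            load_strict = False
        model = GenEfficientNet(**model_kwargs)
        model.default_cfg = default_cfg
        if pretrained:
            load_pretrained(
                model, default_cfg,
                num_classes=model_kwargs.get('num_classes', 1000),
                in_chans=model_kwargs.get('in_chans', 3),
                strict=load_strict)
        return model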
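The arch_def lists passed to _decode_arch_def encode one block per underscore-separated string: block type, then r (repeats), k (kernel size, dotted for the MixNet multi-kernel groups), s (stride), e (expansion ratio), c (output channels), se (squeeze-excite ratio) and cc (CondConv experts, which experts_multiplier then scales, so the _8e variants get 8 experts from the same _cc4 definitions). The real parser is not shown in this hunk; the toy decoder below only illustrates the token layout.

    def toy_decode_block(block_str):
        """Illustrative only; mirrors the token layout of the arch_def strings above."""
        tokens = block_str.split('_')
        cfg = {'type': tokens[0]}  # 'ds', 'ir', 'er', 'cn', ...
        for tok in tokens[1:]:
            if tok == 'noskip':
                cfg['skip'] = False
            elif tok == 'nsw':
                cfg['act'] = 'swish'  # per the '# swish' comments in the MixNet defs
            elif tok.startswith('se'):
                cfg['se_ratio'] = float(tok[2:])
            elif tok.startswith('cc'):
                cfg['num_experts'] = int(tok[2:])  # CondConv experts, before experts_multiplier
            elif tok[0] == 'r':
                cfg['repeat'] = int(tok[1:])
            elif tok[0] == 'k':
                # dotted kernel lists, e.g. 'k3.5.7.9' -> [3, 5, 7, 9] (MixNet)
                ks = [int(k) for k in tok[1:].split('.')]
                cfg['kernel_size'] = ks[0] if len(ks) == 1 else ks
            elif tok[0] == 's':
                cfg['stride'] = int(tok[1:])
            elif tok[0] == 'e':
                cfg['exp_ratio'] = float(tok[1:])
            elif tok[0] == 'c':
                cfg['out_chs'] = int(tok[1:])
            # other tokens ('fc24', 'p1.1', ...) are ignored in this toy version
        return cfg

    print(toy_decode_block('ir_r4_k5_s2_e6_c192_se0.25_cc4'))
    # {'type': 'ir', 'repeat': 4, 'kernel_size': 5, 'stride': 2, 'exp_ratio': 6.0,
    #  'out_chs': 192, 'se_ratio': 0.25, 'num_experts': 4}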
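For context on what the new _cc tokens buy: a conditionally parameterized convolution keeps several expert kernels and mixes them per example with routing weights computed from the input. The toy module below shows only that core idea for a single example; it is not the CondConv layer this patch actually adds, and its routing and initialization details are simplified.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class CondConvSketch(nn.Module):
        """Minimal, single-example illustration of conditionally parameterized conv."""
        def __init__(self, in_chs, out_chs, kernel_size=3, num_experts=4):
            super().__init__()
            self.routing = nn.Linear(in_chs, num_experts)
            self.weight = nn.Parameter(
                torch.randn(num_experts, out_chs, in_chs, kernel_size, kernel_size) * 0.01)

        def forward(self, x):
            # x: (1, C, H, W), one example for clarity
            r = torch.sigmoid(self.routing(x.mean(dim=(2, 3))))   # (1, num_experts)
            w = torch.einsum('e,eoihw->oihw', r[0], self.weight)  # mix expert kernels
            return F.conv2d(x, w, padding=self.weight.shape[-1] // 2)

    m = CondConvSketch(16, 32, num_experts=4)
    y = m(torch.randn(1, 16, 8, 8))
    print(y.shape)  # torch.Size([1, 32, 8, 8])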
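num_features=_round_channels(1280, channel_multiplier, 8, None) scales the head width by the channel multiplier and snaps it to a multiple of 8. The helper itself is defined earlier in the file; the sketch below is the usual "make divisible" rule these models use and may differ from the real one in edge cases.

    def round_channels_sketch(channels, multiplier=1.0, divisor=8, channel_min=None):
        """Typical EfficientNet-style channel rounding (sketch, not the real helper)."""
        if not multiplier:
            return channels
        channels *= multiplier
        channel_min = channel_min or divisor
        new_channels = max(channel_min, int(channels + divisor / 2) // divisor * divisor)
        if new_channels < 0.9 * channels:  # never round down by more than 10%
            new_channels += divisor
        return new_channels

    print(round_channels_sketch(1280, 1.1, 8, None))  # 1408, the efficientnet_b2 head width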
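For quick reference, the (channel_multiplier, depth_multiplier) pairs passed by the efficientnet_b0 through efficientnet_b7 entrypoints above, collected in one place (the tf_ variants use the same values):

    EFFICIENTNET_SCALING = {
        'b0': (1.0, 1.0),
        'b1': (1.0, 1.1),
        'b2': (1.1, 1.2),
        'b3': (1.2, 1.4),
        'b4': (1.4, 1.8),
        'b5': (1.6, 2.2),
        'b6': (1.8, 2.6),
        'b7': (2.0, 3.1),
    }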
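Since every entrypoint is registered via @register_model and resolves its weights through default_cfgs[variant], the intended way to build these models is timm's factory. A usage sketch (pretrained=False so no checkpoint is fetched; the commented drop-rate kwargs mirror the NOTE comments above):

    import torch
    from timm import create_model

    model = create_model('efficientnet_cc_b0_4e', pretrained=False, num_classes=1000)
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # torch.Size([1, 1000])

    # when training, the NOTE comments suggest passing drop rates the same way, e.g.
    # model = create_model('efficientnet_b0', drop_rate=0.2, drop_connect_rate=0.2)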
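Finally, the helpers.py hunk turns the previously hard-coded strict flag in load_pretrained into a keyword argument with the same default, so existing callers are unaffected while callers whose state_dict deliberately diverges from the checkpoint (for example a trimmed-down feature extractor) can opt out. A usage sketch, assuming model is an already constructed model with a default_cfg:

    from timm.models.helpers import load_pretrained

    # unchanged default behaviour, strict=True
    load_pretrained(model, model.default_cfg, num_classes=1000, in_chans=3)

    # tolerate missing or unexpected keys instead of raising
    load_pretrained(model, model.default_cfg, num_classes=1000, in_chans=3, strict=False)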