diff --git a/clean_checkpoint.py b/clean_checkpoint.py index 59a6e306..d51e0d96 100644 --- a/clean_checkpoint.py +++ b/clean_checkpoint.py @@ -2,6 +2,7 @@ import torch import argparse import os import hashlib +import shutil from collections import OrderedDict parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation') @@ -31,10 +32,9 @@ def main(): if state_dict_key in checkpoint: state_dict = checkpoint[state_dict_key] else: - print("Error: No state_dict found in checkpoint {}.".format(args.checkpoint)) - exit(1) + state_dict = checkpoint else: - state_dict = checkpoint + assert False for k, v in state_dict.items(): name = k[7:] if k.startswith('module') else k new_state_dict[name] = v @@ -43,7 +43,11 @@ def main(): torch.save(new_state_dict, args.output) with open(args.output, 'rb') as f: sha_hash = hashlib.sha256(f.read()).hexdigest() - print("=> Saved state_dict to '{}, SHA256: {}'".format(args.output, sha_hash)) + + checkpoint_base = os.path.splitext(args.checkpoint)[0] + final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + '.pth' + shutil.move(args.output, final_filename) + print("=> Saved state_dict to '{}, SHA256: {}'".format(final_filename, sha_hash)) else: print("Error: Checkpoint ({}) doesn't exist".format(args.checkpoint)) diff --git a/requirements.txt b/requirements.txt index 88ce152f..f05f9812 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -torch>=1.1.0 -torchvision>=0.3.0 +torch>=1.2.0 +torchvision>=0.4.0 pyyaml diff --git a/sotabench.py b/sotabench.py index 5b61a93f..cd25412f 100644 --- a/sotabench.py +++ b/sotabench.py @@ -78,7 +78,7 @@ model_list = [ _entry('mixnet_m', 'MixNet-M', '1907.09595'), _entry('mixnet_s', 'MixNet-S', '1907.09595'), _entry('mnasnet_100', 'MnasNet-B1', '1807.11626'), - _entry('mobilenetv3_100', 'MobileNet V3-Large 1.0', '1905.02244', + _entry('mobilenetv3_rw', 'MobileNet V3-Large 1.0', '1905.02244', model_desc='Trained in PyTorch with RMSProp, exponential LR decay, and hyper-params matching ' 'paper as closely as possible.'), _entry('resnet18', 'ResNet-18', '1812.01187'), @@ -114,6 +114,30 @@ model_list = [ model_desc='Ported from official Google AI Tensorflow weights'), _entry('tf_efficientnet_b7', 'EfficientNet-B7 (RandAugment)', '1905.11946', batch_size=BATCH_SIZE//8, model_desc='Ported from official Google AI Tensorflow weights'), + _entry('tf_efficientnet_b0_ap', 'EfficientNet-B0 (AdvProp)', '1911.09665', + model_desc='Ported from official Google AI Tensorflow weights'), + _entry('tf_efficientnet_b1_ap', 'EfficientNet-B1 (AdvProp)', '1911.09665', + model_desc='Ported from official Google AI Tensorflow weights'), + _entry('tf_efficientnet_b2_ap', 'EfficientNet-B2 (AdvProp)', '1911.09665', + model_desc='Ported from official Google AI Tensorflow weights'), + _entry('tf_efficientnet_b3_ap', 'EfficientNet-B3 (AdvProp)', '1911.09665', batch_size=BATCH_SIZE // 2, + model_desc='Ported from official Google AI Tensorflow weights'), + _entry('tf_efficientnet_b4_ap', 'EfficientNet-B4 (AdvProp)', '1911.09665', batch_size=BATCH_SIZE // 2, + model_desc='Ported from official Google AI Tensorflow weights'), + _entry('tf_efficientnet_b5_ap', 'EfficientNet-B5 (AdvProp)', '1911.09665', batch_size=BATCH_SIZE // 4, + model_desc='Ported from official Google AI Tensorflow weights'), + _entry('tf_efficientnet_b6_ap', 'EfficientNet-B6 (AdvProp)', '1911.09665', batch_size=BATCH_SIZE // 8, + model_desc='Ported from official Google AI Tensorflow weights'), + _entry('tf_efficientnet_b7_ap', 'EfficientNet-B7 (AdvProp)', 
'1911.09665', batch_size=BATCH_SIZE // 8, + model_desc='Ported from official Google AI Tensorflow weights'), + _entry('tf_efficientnet_b8_ap', 'EfficientNet-B8 (AdvProp)', '1911.09665', batch_size=BATCH_SIZE // 8, + model_desc='Ported from official Google AI Tensorflow weights'), + _entry('tf_efficientnet_cc_b0_4e', 'EfficientNet-CondConv-B0 4 experts', '1904.04971', + model_desc='Ported from official Google AI Tensorflow weights'), + _entry('tf_efficientnet_cc_b0_8e', 'EfficientNet-CondConv-B0 8 experts', '1904.04971', + model_desc='Ported from official Google AI Tensorflow weights'), + _entry('tf_efficientnet_cc_b1_8e', 'EfficientNet-CondConv-B1 8 experts', '1904.04971', + model_desc='Ported from official Google AI Tensorflow weights'), _entry('tf_efficientnet_es', 'EfficientNet-EdgeTPU-S', '1905.11946', model_desc='Ported from official Google AI Tensorflow weights'), _entry('tf_efficientnet_em', 'EfficientNet-EdgeTPU-M', '1905.11946', @@ -124,6 +148,18 @@ model_list = [ _entry('tf_mixnet_l', 'MixNet-L', '1907.09595', model_desc='Ported from official Google AI Tensorflow weights'), _entry('tf_mixnet_m', 'MixNet-M', '1907.09595', model_desc='Ported from official Google AI Tensorflow weights'), _entry('tf_mixnet_s', 'MixNet-S', '1907.09595', model_desc='Ported from official Google AI Tensorflow weights'), + _entry('tf_mobilenetv3_large_100', 'MobileNet V3-Large 1.0', '1905.02244', + model_desc='Ported from official Google AI Tensorflow weights'), + _entry('tf_mobilenetv3_large_075', 'MobileNet V3-Large 0.75', '1905.02244', + model_desc='Ported from official Google AI Tensorflow weights'), + _entry('tf_mobilenetv3_large_minimal_100', 'MobileNet V3-Large Minimal 1.0', '1905.02244', + model_desc='Ported from official Google AI Tensorflow weights'), + _entry('tf_mobilenetv3_small_100', 'MobileNet V3-Small 1.0', '1905.02244', + model_desc='Ported from official Google AI Tensorflow weights'), + _entry('tf_mobilenetv3_small_075', 'MobileNet V3-Small 0.75', '1905.02244', + model_desc='Ported from official Google AI Tensorflow weights'), + _entry('tf_mobilenetv3_small_minimal_100', 'MobileNet V3-Small Minimal 1.0', '1905.02244', + model_desc='Ported from official Google AI Tensorflow weights'), ## Cadene ported weights (to remove if Cadene adds sotabench) _entry('inception_resnet_v2', 'Inception ResNet V2', '1602.07261'), diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 3c7e8e47..7119c4f5 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -7,12 +7,14 @@ from .senet import * from .xception import * from .nasnet import * from .pnasnet import * -from .gen_efficientnet import * +from .efficientnet import * +from .mobilenetv3 import * from .inception_v3 import * from .gluon_resnet import * from .gluon_xception import * from .res2net import * from .dla import * +from .hrnet import * from .registry import * from .factory import create_model diff --git a/timm/models/activations.py b/timm/models/activations.py new file mode 100644 index 00000000..aafa290c --- /dev/null +++ b/timm/models/activations.py @@ -0,0 +1,155 @@ +import torch +from torch import nn as nn +from torch.nn import functional as F + + +_USE_MEM_EFFICIENT_ISH = True +if _USE_MEM_EFFICIENT_ISH: + # This version reduces memory overhead of Swish during training by + # recomputing torch.sigmoid(x) in backward instead of saving it. 
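    # For reference, with s = sigmoid(x):
    #   swish(x)  = x * s
    #   swish'(x) = s + x * s * (1 - s) = s * (1 + x * (1 - s))
    # so swish_jit_bwd below can reconstruct the gradient from x alone.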
+ @torch.jit.script + def swish_jit_fwd(x): + return x.mul(torch.sigmoid(x)) + + + @torch.jit.script + def swish_jit_bwd(x, grad_output): + x_sigmoid = torch.sigmoid(x) + return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid))) + + + class SwishJitAutoFn(torch.autograd.Function): + """ torch.jit.script optimised Swish + Inspired by conversation btw Jeremy Howard & Adam Pazske + https://twitter.com/jeremyphoward/status/1188251041835315200 + """ + + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return swish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return swish_jit_bwd(x, grad_output) + + + def swish(x, _inplace=False): + return SwishJitAutoFn.apply(x) + + + @torch.jit.script + def mish_jit_fwd(x): + return x.mul(torch.tanh(F.softplus(x))) + + + @torch.jit.script + def mish_jit_bwd(x, grad_output): + x_sigmoid = torch.sigmoid(x) + x_tanh_sp = F.softplus(x).tanh() + return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp)) + + + class MishJitAutoFn(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return mish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return mish_jit_bwd(x, grad_output) + + def mish(x, _inplace=False): + return MishJitAutoFn.apply(x) + +else: + def swish(x, inplace=False): + """Swish - Described in: https://arxiv.org/abs/1710.05941 + """ + return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) + + + def mish(x, _inplace=False): + """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + """ + return x.mul(F.softplus(x).tanh()) + + +class Swish(nn.Module): + def __init__(self, inplace=False): + super(Swish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return swish(x, self.inplace) + + +class Mish(nn.Module): + def __init__(self, inplace=False): + super(Mish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return mish(x, self.inplace) + + +def sigmoid(x, inplace=False): + return x.sigmoid_() if inplace else x.sigmoid() + + +# PyTorch has this, but not with a consistent inplace argmument interface +class Sigmoid(nn.Module): + def __init__(self, inplace=False): + super(Sigmoid, self).__init__() + self.inplace = inplace + + def forward(self, x): + return x.sigmoid_() if self.inplace else x.sigmoid() + + +def tanh(x, inplace=False): + return x.tanh_() if inplace else x.tanh() + + +# PyTorch has this, but not with a consistent inplace argmument interface +class Tanh(nn.Module): + def __init__(self, inplace=False): + super(Tanh, self).__init__() + self.inplace = inplace + + def forward(self, x): + return x.tanh_() if self.inplace else x.tanh() + + +def hard_swish(x, inplace=False): + inner = F.relu6(x + 3.).div_(6.) + return x.mul_(inner) if inplace else x.mul(inner) + + +class HardSwish(nn.Module): + def __init__(self, inplace=False): + super(HardSwish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return hard_swish(x, self.inplace) + + +def hard_sigmoid(x, inplace=False): + if inplace: + return x.add_(3.).clamp_(0., 6.).div_(6.) + else: + return F.relu6(x + 3.) / 6. 
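# Minimal sanity-check sketch (a hypothetical helper, assuming the defaults above):
# the memory-efficient autograd paths should match the naive formulas in value, and
# their hand-written backwards should agree with finite differences via gradcheck.
def _activation_sanity_check():
    x = torch.randn(4, 8)
    assert torch.allclose(swish(x), x * torch.sigmoid(x))
    assert torch.allclose(mish(x), x * F.softplus(x).tanh())
    assert torch.allclose(hard_swish(x), x * hard_sigmoid(x))
    xd = torch.randn(4, 8, dtype=torch.double, requires_grad=True)  # gradcheck wants float64
    assert torch.autograd.gradcheck(swish, (xd,))
    assert torch.autograd.gradcheck(mish, (xd,))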
+ + +class HardSigmoid(nn.Module): + def __init__(self, inplace=False): + super(HardSigmoid, self).__init__() + self.inplace = inplace + + def forward(self, x): + return hard_sigmoid(x, self.inplace) + diff --git a/timm/models/conv2d_helpers.py b/timm/models/conv2d_helpers.py deleted file mode 100644 index 674eadca..00000000 --- a/timm/models/conv2d_helpers.py +++ /dev/null @@ -1,120 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import math - - -def _is_static_pad(kernel_size, stride=1, dilation=1, **_): - return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 - - -def _get_padding(kernel_size, stride=1, dilation=1, **_): - padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 - return padding - - -def _calc_same_pad(i, k, s, d): - return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) - - -def _split_channels(num_chan, num_groups): - split = [num_chan // num_groups for _ in range(num_groups)] - split[0] += num_chan - sum(split) - return split - - -class Conv2dSame(nn.Conv2d): - """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions - """ - def __init__(self, in_channels, out_channels, kernel_size, stride=1, - padding=0, dilation=1, groups=1, bias=True): - super(Conv2dSame, self).__init__( - in_channels, out_channels, kernel_size, stride, 0, dilation, - groups, bias) - - def forward(self, x): - ih, iw = x.size()[-2:] - kh, kw = self.weight.size()[-2:] - pad_h = _calc_same_pad(ih, kh, self.stride[0], self.dilation[0]) - pad_w = _calc_same_pad(iw, kw, self.stride[1], self.dilation[1]) - if pad_h > 0 or pad_w > 0: - x = F.pad(x, [pad_w//2, pad_w - pad_w//2, pad_h//2, pad_h - pad_h//2]) - return F.conv2d(x, self.weight, self.bias, self.stride, - self.padding, self.dilation, self.groups) - - -def conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): - padding = kwargs.pop('padding', '') - kwargs.setdefault('bias', False) - if isinstance(padding, str): - # for any string padding, the padding will be calculated for you, one of three ways - padding = padding.lower() - if padding == 'same': - # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact - if _is_static_pad(kernel_size, **kwargs): - # static case, no extra overhead - padding = _get_padding(kernel_size, **kwargs) - return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) - else: - # dynamic padding - return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) - elif padding == 'valid': - # 'VALID' padding, same as padding=0 - return nn.Conv2d(in_chs, out_chs, kernel_size, padding=0, **kwargs) - else: - # Default to PyTorch style 'same'-ish symmetric padding - padding = _get_padding(kernel_size, **kwargs) - return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) - else: - # padding was specified as a number or pair - return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) - - -class MixedConv2d(nn.Module): - """ Mixed Grouped Convolution - Based on MDConv and GroupedConv in MixNet impl: - https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py - """ - - def __init__(self, in_channels, out_channels, kernel_size=3, - stride=1, padding='', dilated=False, depthwise=False, **kwargs): - super(MixedConv2d, self).__init__() - - kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] - num_groups = len(kernel_size) - in_splits = _split_channels(in_channels, num_groups) - out_splits = _split_channels(out_channels, num_groups) - for idx, (k, in_ch, 
out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): - d = 1 - # FIXME make compat with non-square kernel/dilations/strides - if stride == 1 and dilated: - d, k = (k - 1) // 2, 3 - conv_groups = out_ch if depthwise else 1 - # use add_module to keep key space clean - self.add_module( - str(idx), - conv2d_pad( - in_ch, out_ch, k, stride=stride, - padding=padding, dilation=d, groups=conv_groups, **kwargs) - ) - self.splits = in_splits - - def forward(self, x): - x_split = torch.split(x, self.splits, 1) - x_out = [c(x) for x, c in zip(x_split, self._modules.values())] - x = torch.cat(x_out, 1) - return x - - -# helper method -def select_conv2d(in_chs, out_chs, kernel_size, **kwargs): - assert 'groups' not in kwargs # only use 'depthwise' bool arg - if isinstance(kernel_size, list): - # We're going to use only lists for defining the MixedConv2d kernel groups, - # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. - return MixedConv2d(in_chs, out_chs, kernel_size, **kwargs) - else: - depthwise = kwargs.pop('depthwise', False) - groups = out_chs if depthwise else 1 - return conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs) - diff --git a/timm/models/conv2d_layers.py b/timm/models/conv2d_layers.py new file mode 100644 index 00000000..acd14fde --- /dev/null +++ b/timm/models/conv2d_layers.py @@ -0,0 +1,260 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch._six import container_abcs +from itertools import repeat +from functools import partial +import numpy as np +import math + + +# Tuple helpers ripped from PyTorch +def _ntuple(n): + def parse(x): + if isinstance(x, container_abcs.Iterable): + return x + return tuple(repeat(x, n)) + return parse + + +_single = _ntuple(1) +_pair = _ntuple(2) +_triple = _ntuple(3) +_quadruple = _ntuple(4) + + +def _is_static_pad(kernel_size, stride=1, dilation=1, **_): + return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 + + +def _get_padding(kernel_size, stride=1, dilation=1, **_): + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + return padding + + +def _calc_same_pad(i, k, s, d): + return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) + + +def _split_channels(num_chan, num_groups): + split = [num_chan // num_groups for _ in range(num_groups)] + split[0] += num_chan - sum(split) + return split + + +# pylint: disable=unused-argument +def conv2d_same(x, weight, bias=None, stride=(1, 1), padding=(0, 0), dilation=(1, 1), groups=1): + ih, iw = x.size()[-2:] + kh, kw = weight.size()[-2:] + pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0]) + pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1]) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) + return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) + + +class Conv2dSame(nn.Conv2d): + """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions + """ + + # pylint: disable=unused-argument + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True): + super(Conv2dSame, self).__init__( + in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) + + def forward(self, x): + return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +def get_padding_value(padding, kernel_size, **kwargs): + dynamic = False + if isinstance(padding, str): + # for any string padding, the padding will be calculated for 
you, one of three ways + padding = padding.lower() + if padding == 'same': + # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact + if _is_static_pad(kernel_size, **kwargs): + # static case, no extra overhead + padding = _get_padding(kernel_size, **kwargs) + else: + # dynamic 'SAME' padding, has runtime/GPU memory overhead + padding = 0 + dynamic = True + elif padding == 'valid': + # 'VALID' padding, same as padding=0 + padding = 0 + else: + # Default to PyTorch style 'same'-ish symmetric padding + padding = _get_padding(kernel_size, **kwargs) + return padding, dynamic + + +def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): + padding = kwargs.pop('padding', '') + kwargs.setdefault('bias', False) + padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) + if is_dynamic: + return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) + else: + return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) + + +class MixedConv2d(nn.Module): + """ Mixed Grouped Convolution + Based on MDConv and GroupedConv in MixNet impl: + https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py + + NOTE: This does not currently work with torch.jit.script + """ + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding='', dilation=1, depthwise=False, **kwargs): + super(MixedConv2d, self).__init__() + + kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] + num_groups = len(kernel_size) + in_splits = _split_channels(in_channels, num_groups) + out_splits = _split_channels(out_channels, num_groups) + self.in_channels = sum(in_splits) + self.out_channels = sum(out_splits) + for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): + conv_groups = out_ch if depthwise else 1 + # use add_module to keep key space clean + self.add_module( + str(idx), + create_conv2d_pad( + in_ch, out_ch, k, stride=stride, + padding=padding, dilation=dilation, groups=conv_groups, **kwargs) + ) + self.splits = in_splits + + def forward(self, x): + x_split = torch.split(x, self.splits, 1) + x_out = [c(x) for x, c in zip(x_split, self._modules.values())] + x = torch.cat(x_out, 1) + return x + + +def get_condconv_initializer(initializer, num_experts, expert_shape): + def condconv_initializer(weight): + """CondConv initializer function.""" + num_params = np.prod(expert_shape) + if (len(weight.shape) != 2 or weight.shape[0] != num_experts or + weight.shape[1] != num_params): + raise (ValueError( + 'CondConv variables must have shape [num_experts, num_params]')) + for i in range(num_experts): + initializer(weight[i].view(expert_shape)) + return condconv_initializer + + +class CondConv2d(nn.Module): + """ Conditional Convolution + Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py + + Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion: + https://github.com/pytorch/pytorch/issues/17983 + """ + __constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding'] + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4): + super(CondConv2d, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + padding_val, is_padding_dynamic = get_padding_value( + padding, 
kernel_size, stride=stride, dilation=dilation) + self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript + self.padding = _pair(padding_val) + self.dilation = _pair(dilation) + self.groups = groups + self.num_experts = num_experts + + self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size + weight_num_param = 1 + for wd in self.weight_shape: + weight_num_param *= wd + self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param)) + + if bias: + self.bias_shape = (self.out_channels,) + self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels)) + else: + self.register_parameter('bias', None) + + self.reset_parameters() + + def reset_parameters(self): + init_weight = get_condconv_initializer( + partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape) + init_weight(self.weight) + if self.bias is not None: + fan_in = np.prod(self.weight_shape[1:]) + bound = 1 / math.sqrt(fan_in) + init_bias = get_condconv_initializer( + partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape) + init_bias(self.bias) + + def forward(self, x, routing_weights): + B, C, H, W = x.shape + weight = torch.matmul(routing_weights, self.weight) + new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size + weight = weight.view(new_weight_shape) + bias = None + if self.bias is not None: + bias = torch.matmul(routing_weights, self.bias) + bias = bias.view(B * self.out_channels) + # move batch elements with channels so each batch element can be efficiently convolved with separate kernel + x = x.view(1, B * C, H, W) + if self.dynamic_padding: + out = conv2d_same( + x, weight, bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups * B) + else: + out = F.conv2d( + x, weight, bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups * B) + out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1]) + + # Literal port (from TF definition) + # x = torch.split(x, 1, 0) + # weight = torch.split(weight, 1, 0) + # if self.bias is not None: + # bias = torch.matmul(routing_weights, self.bias) + # bias = torch.split(bias, 1, 0) + # else: + # bias = [None] * B + # out = [] + # for xi, wi, bi in zip(x, weight, bias): + # wi = wi.view(*self.weight_shape) + # if bi is not None: + # bi = bi.view(*self.bias_shape) + # out.append(self.conv_fn( + # xi, wi, bi, stride=self.stride, padding=self.padding, + # dilation=self.dilation, groups=self.groups)) + # out = torch.cat(out, 0) + return out + + +# helper method +def select_conv2d(in_chs, out_chs, kernel_size, **kwargs): + assert 'groups' not in kwargs # only use 'depthwise' bool arg + if isinstance(kernel_size, list): + assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently + # We're going to use only lists for defining the MixedConv2d kernel groups, + # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. 
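        # e.g. kernel_size=[3, 5, 7] splits the channels into three groups convolved with
        # 3x3, 5x5 and 7x7 kernels respectively, then concatenates the results back together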
+ m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs) + else: + depthwise = kwargs.pop('depthwise', False) + groups = out_chs if depthwise else 1 + if 'num_experts' in kwargs and kwargs['num_experts'] > 0: + m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs) + else: + m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs) + return m + + diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py new file mode 100644 index 00000000..9163a023 --- /dev/null +++ b/timm/models/efficientnet.py @@ -0,0 +1,1287 @@ +""" PyTorch EfficientNet Family + +An implementation of EfficienNet that covers variety of related models with efficient architectures: + +* EfficientNet (B0-B8 + Tensorflow pretrained AutoAug/RandAug/AdvProp weight ports) + - EfficientNet: Rethinking Model Scaling for CNNs - https://arxiv.org/abs/1905.11946 + - CondConv: Conditionally Parameterized Convolutions for Efficient Inference - https://arxiv.org/abs/1904.04971 + - Adversarial Examples Improve Image Recognition - https://arxiv.org/abs/1911.09665 + +* MixNet (Small, Medium, and Large) + - MixConv: Mixed Depthwise Convolutional Kernels - https://arxiv.org/abs/1907.09595 + +* MNasNet B1, A1 (SE), Small + - MnasNet: Platform-Aware Neural Architecture Search for Mobile - https://arxiv.org/abs/1807.11626 + +* FBNet-C + - FBNet: Hardware-Aware Efficient ConvNet Design via Differentiable NAS - https://arxiv.org/abs/1812.03443 + +* Single-Path NAS Pixel1 + - Single-Path NAS: Designing Hardware-Efficient ConvNets - https://arxiv.org/abs/1904.02877 + +* And likely more... + +Hacked together by Ross Wightman +""" +from .efficientnet_builder import * +from .feature_hooks import FeatureHooks +from .registry import register_model +from .helpers import load_pretrained +from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .conv2d_layers import select_conv2d +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD + + +__all__ = ['EfficientNet'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv_stem', 'classifier': 'classifier', + **kwargs + } + + +default_cfgs = { + 'mnasnet_050': _cfg(url=''), + 'mnasnet_075': _cfg(url=''), + 'mnasnet_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth'), + 'mnasnet_140': _cfg(url=''), + 'semnasnet_050': _cfg(url=''), + 'semnasnet_075': _cfg(url=''), + 'semnasnet_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth'), + 'semnasnet_140': _cfg(url=''), + 'mnasnet_small': _cfg(url=''), + 'mobilenetv2_100': _cfg(url=''), + 'fbnetc_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth', + interpolation='bilinear'), + 'spnasnet_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/spnasnet_100-048bc3f4.pth', + interpolation='bilinear'), + 'efficientnet_b0': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0-d6904d92.pth'), + 'efficientnet_b1': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth', + 
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'efficientnet_b2': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2-cf78dc4d.pth', + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), + 'efficientnet_b3': _cfg( + url='', input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'efficientnet_b4': _cfg( + url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), + 'efficientnet_b5': _cfg( + url='', input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), + 'efficientnet_b6': _cfg( + url='', input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942), + 'efficientnet_b7': _cfg( + url='', input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949), + 'efficientnet_b8': _cfg( + url='', input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954), + 'efficientnet_es': _cfg( + url=''), + 'efficientnet_em': _cfg( + url='', input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'efficientnet_el': _cfg( + url='', input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'efficientnet_cc_b0_4e': _cfg(url=''), + 'efficientnet_cc_b0_8e': _cfg(url=''), + 'efficientnet_cc_b1_8e': _cfg(url='', input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'tf_efficientnet_b0': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth', + input_size=(3, 224, 224)), + 'tf_efficientnet_b1': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth', + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'tf_efficientnet_b2': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97.pth', + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), + 'tf_efficientnet_b3': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth', + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'tf_efficientnet_b4': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth', + input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), + 'tf_efficientnet_b5': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth', + input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), + 'tf_efficientnet_b6': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth', + input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942), + 'tf_efficientnet_b7': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth', + input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949), + 'tf_efficientnet_b0_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ap-f262efe1.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, input_size=(3, 224, 224)), + 'tf_efficientnet_b1_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ap-44ef0a3d.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 
'tf_efficientnet_b2_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ap-2f8e7636.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), + 'tf_efficientnet_b3_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ap-aad25bdd.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'tf_efficientnet_b4_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ap-dedb23e6.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), + 'tf_efficientnet_b5_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ap-9e82fae8.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), + 'tf_efficientnet_b6_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ap-4ffb161f.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942), + 'tf_efficientnet_b7_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ap-ddb28fec.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949), + 'tf_efficientnet_b8_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ap-00e169fa.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954), + 'tf_efficientnet_es': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 224, 224), ), + 'tf_efficientnet_em': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_em-e78cfe58.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'tf_efficientnet_el': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_el-5143854e.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'tf_efficientnet_cc_b0_4e': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_4e-4362b6b2.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_efficientnet_cc_b0_8e': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_8e-66184a25.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_efficientnet_cc_b1_8e': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b1_8e-f7c79ae1.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'mixnet_s': _cfg( + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth'), + 'mixnet_m': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth'), + 'mixnet_l': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_l-5a9a2ed8.pth'), + 'mixnet_xl': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_xl-ac5fbe8d.pth'), + 'mixnet_xxl': _cfg(), + 'tf_mixnet_s': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth'), + 'tf_mixnet_m': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_m-0f4d8805.pth'), + 'tf_mixnet_l': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_l-6c92e0c8.pth'), +} + +_DEBUG = False + + +class EfficientNet(nn.Module): + """ (Generic) EfficientNet + + A flexible and performant PyTorch implementation of efficient network architectures, including: + * EfficientNet B0-B8 + * EfficientNet-EdgeTPU + * EfficientNet-CondConv + * MixNet S, M, L, XL + * MnasNet A1, B1, and small + * FBNet C + * Single-Path NAS Pixel1 + + """ + + def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32, + channel_multiplier=1.0, channel_divisor=8, channel_min=None, + pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0., + se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, + global_pool='avg', weight_init='goog'): + super(EfficientNet, self).__init__() + norm_kwargs = norm_kwargs or {} + + self.num_classes = num_classes + self.num_features = num_features + self.drop_rate = drop_rate + self._in_chs = in_chans + + # Stem + stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min) + self.conv_stem = select_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = norm_layer(stem_size, **norm_kwargs) + self.act1 = act_layer(inplace=True) + self._in_chs = stem_size + + # Middle stages (IR/ER/DS Blocks) + builder = EfficientNetBuilder( + channel_multiplier, channel_divisor, channel_min, 32, pad_type, act_layer, se_kwargs, + norm_layer, norm_kwargs, drop_connect_rate, verbose=_DEBUG) + self.blocks = nn.Sequential(*builder(self._in_chs, block_args)) + self.feature_info = builder.features + self._in_chs = builder.in_chs + + # Head + Pooling + self.conv_head = select_conv2d(self._in_chs, self.num_features, 1, padding=pad_type) + self.bn2 = norm_layer(self.num_features, **norm_kwargs) + self.act2 = act_layer(inplace=True) + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + + # Classifier + self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), self.num_classes) + + for m in self.modules(): + if weight_init == 'goog': + efficientnet_init_goog(m) + else: + efficientnet_init_default(m) + + def as_sequential(self): + layers = [self.conv_stem, self.bn1, self.act1] + layers.extend(self.blocks) + layers.extend([self.conv_head, self.bn2, self.act2, self.global_pool]) + layers.extend([nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier]) + return nn.Sequential(*layers) + + def get_classifier(self): + return self.classifier + + def reset_classifier(self, num_classes, global_pool='avg'): + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.num_classes = num_classes + del self.classifier + if 
num_classes: + self.classifier = nn.Linear( + self.num_features * self.global_pool.feat_mult(), num_classes) + else: + self.classifier = None + + def forward_features(self, x): + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + x = self.blocks(x) + x = self.conv_head(x) + x = self.bn2(x) + x = self.act2(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + x = x.flatten(1) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + return self.classifier(x) + + +class EfficientNetFeatures(nn.Module): + """ EfficientNet Feature Extractor + + A work-in-progress feature extraction module for EfficientNet, to use as a backbone for segmentation + and object detection models. + """ + + def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl', + in_chans=3, stem_size=32, channel_multiplier=1.0, channel_divisor=8, channel_min=None, + output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0., + se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, weight_init='goog'): + super(EfficientNetFeatures, self).__init__() + norm_kwargs = norm_kwargs or {} + + # TODO only create stages needed, currently all stages are created regardless of out_indices + num_stages = max(out_indices) + 1 + + self.out_indices = out_indices + self.drop_rate = drop_rate + self._in_chs = in_chans + + # Stem + stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min) + self.conv_stem = select_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = norm_layer(stem_size, **norm_kwargs) + self.act1 = act_layer(inplace=True) + self._in_chs = stem_size + + # Middle stages (IR/ER/DS Blocks) + builder = EfficientNetBuilder( + channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs, + norm_layer, norm_kwargs, drop_connect_rate, feature_location=feature_location, verbose=_DEBUG) + self.blocks = nn.Sequential(*builder(self._in_chs, block_args)) + self.feature_info = builder.features # builder provides info about feature channels for each block + self._in_chs = builder.in_chs + + for m in self.modules(): + if weight_init == 'goog': + efficientnet_init_goog(m) + else: + efficientnet_init_default(m) + + if _DEBUG: + for k, v in self.feature_info.items(): + print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs'])) + + # Register feature extraction hooks with FeatureHooks helper + hook_type = 'forward_pre' if feature_location == 'pre_pwl' else 'forward' + hooks = [dict(name=self.feature_info[idx]['name'], type=hook_type) for idx in out_indices] + self.feature_hooks = FeatureHooks(hooks, self.named_modules()) + + def feature_channels(self, idx=None): + """ Feature Channel Shortcut + Returns feature channel count for each output index if idx == None. If idx is an integer, will + return feature channel count for that feature block index (independent of out_indices setting). 
+ """ + if isinstance(idx, int): + return self.feature_info[idx]['num_chs'] + return [self.feature_info[i]['num_chs'] for i in self.out_indices] + + def forward(self, x): + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + self.blocks(x) + return self.feature_hooks.get_output(x.device) + + +def _create_model(model_kwargs, default_cfg, pretrained=False): + if model_kwargs.pop('features_only', False): + load_strict = False + model_kwargs.pop('num_classes', 0) + model_kwargs.pop('num_features', 0) + model_kwargs.pop('head_conv', None) + model_class = EfficientNetFeatures + else: + load_strict = True + model_class = EfficientNet + + model = model_class(**model_kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained( + model, + default_cfg, + num_classes=model_kwargs.get('num_classes', 0), + in_chans=model_kwargs.get('in_chans', 3), + strict=load_strict) + return model + + +def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a mnasnet-a1 model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet + Paper: https://arxiv.org/pdf/1807.11626.pdf. + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16_noskip'], + # stage 1, 112x112 in + ['ir_r2_k3_s2_e6_c24'], + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40_se0.25'], + # stage 3, 28x28 in + ['ir_r4_k3_s2_e6_c80'], + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112_se0.25'], + # stage 5, 14x14in + ['ir_r3_k5_s2_e6_c160_se0.25'], + # stage 6, 7x7 in + ['ir_r1_k3_s1_e6_c320'], + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=32, + channel_multiplier=channel_multiplier, + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, default_cfgs[variant], pretrained) + return model + + +def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a mnasnet-b1 model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet + Paper: https://arxiv.org/pdf/1807.11626.pdf. + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_c16_noskip'], + # stage 1, 112x112 in + ['ir_r3_k3_s2_e3_c24'], + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40'], + # stage 3, 28x28 in + ['ir_r3_k5_s2_e6_c80'], + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c96'], + # stage 5, 14x14in + ['ir_r4_k5_s2_e6_c192'], + # stage 6, 7x7 in + ['ir_r1_k3_s1_e6_c320_noskip'] + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=32, + channel_multiplier=channel_multiplier, + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, default_cfgs[variant], pretrained) + return model + + +def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a mnasnet-b1 model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet + Paper: https://arxiv.org/pdf/1807.11626.pdf. + + Args: + channel_multiplier: multiplier to number of channels per layer. 
+ """ + arch_def = [ + ['ds_r1_k3_s1_c8'], + ['ir_r1_k3_s2_e3_c16'], + ['ir_r2_k3_s2_e6_c16'], + ['ir_r4_k5_s2_e6_c32_se0.25'], + ['ir_r3_k3_s1_e6_c32_se0.25'], + ['ir_r3_k5_s2_e6_c88_se0.25'], + ['ir_r1_k3_s1_e6_c144'] + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=8, + channel_multiplier=channel_multiplier, + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, default_cfgs[variant], pretrained) + return model + + +def _gen_mobilenet_v2(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """ Generate MobileNet-V2 network + Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py + Paper: https://arxiv.org/abs/1801.04381 + """ + arch_def = [ + ['ds_r1_k3_s1_c16'], + ['ir_r2_k3_s2_e6_c24'], + ['ir_r3_k3_s2_e6_c32'], + ['ir_r4_k3_s2_e6_c64'], + ['ir_r3_k3_s1_e6_c96'], + ['ir_r3_k3_s2_e6_c160'], + ['ir_r1_k3_s1_e6_c320'], + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=32, + channel_multiplier=channel_multiplier, + norm_kwargs=resolve_bn_args(kwargs), + act_layer=nn.ReLU6, + **kwargs + ) + model = _create_model(model_kwargs, default_cfgs[variant], pretrained) + return model + + +def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """ FBNet-C + + Paper: https://arxiv.org/abs/1812.03443 + Ref Impl: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py + + NOTE: the impl above does not relate to the 'C' variant here, that was derived from paper, + it was used to confirm some building block details + """ + arch_def = [ + ['ir_r1_k3_s1_e1_c16'], + ['ir_r1_k3_s2_e6_c24', 'ir_r2_k3_s1_e1_c24'], + ['ir_r1_k5_s2_e6_c32', 'ir_r1_k5_s1_e3_c32', 'ir_r1_k5_s1_e6_c32', 'ir_r1_k3_s1_e6_c32'], + ['ir_r1_k5_s2_e6_c64', 'ir_r1_k5_s1_e3_c64', 'ir_r2_k5_s1_e6_c64'], + ['ir_r3_k5_s1_e6_c112', 'ir_r1_k5_s1_e3_c112'], + ['ir_r4_k5_s2_e6_c184'], + ['ir_r1_k3_s1_e6_c352'], + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=16, + num_features=1984, # paper suggests this, but is not 100% clear + channel_multiplier=channel_multiplier, + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, default_cfgs[variant], pretrained) + return model + + +def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates the Single-Path NAS model from search targeted for Pixel1 phone. + + Paper: https://arxiv.org/abs/1904.02877 + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_c16_noskip'], + # stage 1, 112x112 in + ['ir_r3_k3_s2_e3_c24'], + # stage 2, 56x56 in + ['ir_r1_k5_s2_e6_c40', 'ir_r3_k3_s1_e3_c40'], + # stage 3, 28x28 in + ['ir_r1_k5_s2_e6_c80', 'ir_r3_k3_s1_e3_c80'], + # stage 4, 14x14in + ['ir_r1_k5_s1_e6_c96', 'ir_r3_k5_s1_e3_c96'], + # stage 5, 14x14in + ['ir_r4_k5_s2_e6_c192'], + # stage 6, 7x7 in + ['ir_r1_k3_s1_e6_c320_noskip'] + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=32, + channel_multiplier=channel_multiplier, + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, default_cfgs[variant], pretrained) + return model + + +def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """Creates an EfficientNet model. 
+ + Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py + Paper: https://arxiv.org/abs/1905.11946 + + EfficientNet params + name: (channel_multiplier, depth_multiplier, resolution, dropout_rate) + 'efficientnet-b0': (1.0, 1.0, 224, 0.2), + 'efficientnet-b1': (1.0, 1.1, 240, 0.2), + 'efficientnet-b2': (1.1, 1.2, 260, 0.3), + 'efficientnet-b3': (1.2, 1.4, 300, 0.3), + 'efficientnet-b4': (1.4, 1.8, 380, 0.4), + 'efficientnet-b5': (1.6, 2.2, 456, 0.4), + 'efficientnet-b6': (1.8, 2.6, 528, 0.5), + 'efficientnet-b7': (2.0, 3.1, 600, 0.5), + 'efficientnet-b8': (2.2, 3.6, 672, 0.5), + + Args: + channel_multiplier: multiplier to number of channels per layer + depth_multiplier: multiplier to number of repeats per stage + + """ + arch_def = [ + ['ds_r1_k3_s1_e1_c16_se0.25'], + ['ir_r2_k3_s2_e6_c24_se0.25'], + ['ir_r2_k5_s2_e6_c40_se0.25'], + ['ir_r3_k3_s2_e6_c80_se0.25'], + ['ir_r3_k5_s1_e6_c112_se0.25'], + ['ir_r4_k5_s2_e6_c192_se0.25'], + ['ir_r1_k3_s1_e6_c320_se0.25'], + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=round_channels(1280, channel_multiplier, 8, None), + stem_size=32, + channel_multiplier=channel_multiplier, + act_layer=Swish, + norm_kwargs=resolve_bn_args(kwargs), + **kwargs, + ) + model = _create_model(model_kwargs, default_cfgs[variant], pretrained) + return model + + +def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """ Creates an EfficientNet-EdgeTPU model + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/edgetpu + """ + + arch_def = [ + # NOTE `fc` is present to override a mismatch between stem channels and in chs not + # present in other models + ['er_r1_k3_s1_e4_c24_fc24_noskip'], + ['er_r2_k3_s2_e8_c32'], + ['er_r4_k3_s2_e8_c48'], + ['ir_r5_k5_s2_e8_c96'], + ['ir_r4_k5_s1_e8_c144'], + ['ir_r2_k5_s2_e8_c192'], + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=round_channels(1280, channel_multiplier, 8, None), + stem_size=32, + channel_multiplier=channel_multiplier, + norm_kwargs=resolve_bn_args(kwargs), + act_layer=nn.ReLU, + **kwargs, + ) + model = _create_model(model_kwargs, default_cfgs[variant], pretrained) + return model + + +def _gen_efficientnet_condconv( + variant, channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=1, pretrained=False, **kwargs): + """Creates an EfficientNet-CondConv model. 
+ + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/condconv + """ + arch_def = [ + ['ds_r1_k3_s1_e1_c16_se0.25'], + ['ir_r2_k3_s2_e6_c24_se0.25'], + ['ir_r2_k5_s2_e6_c40_se0.25'], + ['ir_r3_k3_s2_e6_c80_se0.25'], + ['ir_r3_k5_s1_e6_c112_se0.25_cc4'], + ['ir_r4_k5_s2_e6_c192_se0.25_cc4'], + ['ir_r1_k3_s1_e6_c320_se0.25_cc4'], + ] + # NOTE unlike official impl, this one uses `cc` option where x is the base number of experts for each stage and + # the expert_multiplier increases that on a per-model basis as with depth/channel multipliers + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, experts_multiplier=experts_multiplier), + num_features=round_channels(1280, channel_multiplier, 8, None), + stem_size=32, + channel_multiplier=channel_multiplier, + norm_kwargs=resolve_bn_args(kwargs), + act_layer=Swish, + **kwargs, + ) + model = _create_model(model_kwargs, default_cfgs[variant], pretrained) + return model + + +def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MixNet Small model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet + Paper: https://arxiv.org/abs/1907.09595 + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16'], # relu + # stage 1, 112x112 in + ['ir_r1_k3_a1.1_p1.1_s2_e6_c24', 'ir_r1_k3_a1.1_p1.1_s1_e3_c24'], # relu + # stage 2, 56x56 in + ['ir_r1_k3.5.7_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish + # stage 3, 28x28 in + ['ir_r1_k3.5.7_p1.1_s2_e6_c80_se0.25_nsw', 'ir_r2_k3.5_p1.1_s1_e6_c80_se0.25_nsw'], # swish + # stage 4, 14x14in + ['ir_r1_k3.5.7_a1.1_p1.1_s1_e6_c120_se0.5_nsw', 'ir_r2_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish + # stage 5, 14x14in + ['ir_r1_k3.5.7.9.11_s2_e6_c200_se0.5_nsw', 'ir_r2_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish + # 7x7 + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + num_features=1536, + stem_size=16, + channel_multiplier=channel_multiplier, + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, default_cfgs[variant], pretrained) + return model + + +def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MixNet Medium-Large model. 
+ + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet + Paper: https://arxiv.org/abs/1907.09595 + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c24'], # relu + # stage 1, 112x112 in + ['ir_r1_k3.5.7_a1.1_p1.1_s2_e6_c32', 'ir_r1_k3_a1.1_p1.1_s1_e3_c32'], # relu + # stage 2, 56x56 in + ['ir_r1_k3.5.7.9_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish + # stage 3, 28x28 in + ['ir_r1_k3.5.7_s2_e6_c80_se0.25_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e6_c80_se0.25_nsw'], # swish + # stage 4, 14x14in + ['ir_r1_k3_s1_e6_c120_se0.5_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish + # stage 5, 14x14in + ['ir_r1_k3.5.7.9_s2_e6_c200_se0.5_nsw', 'ir_r3_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish + # 7x7 + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'), + num_features=1536, + stem_size=24, + channel_multiplier=channel_multiplier, + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, default_cfgs[variant], pretrained) + return model + + +@register_model +def mnasnet_050(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 0.5. """ + model = _gen_mnasnet_b1('mnasnet_050', 0.5, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mnasnet_075(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 0.75. """ + model = _gen_mnasnet_b1('mnasnet_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mnasnet_100(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 1.0. """ + model = _gen_mnasnet_b1('mnasnet_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mnasnet_b1(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 1.0. """ + return mnasnet_100(pretrained, **kwargs) + + +@register_model +def mnasnet_140(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 1.4 """ + model = _gen_mnasnet_b1('mnasnet_140', 1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def semnasnet_050(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 0.5 """ + model = _gen_mnasnet_a1('semnasnet_050', 0.5, pretrained=pretrained, **kwargs) + return model + + +@register_model +def semnasnet_075(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 0.75. """ + model = _gen_mnasnet_a1('semnasnet_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def semnasnet_100(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """ + model = _gen_mnasnet_a1('semnasnet_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mnasnet_a1(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """ + return semnasnet_100(pretrained, **kwargs) + + +@register_model +def semnasnet_140(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 1.4. """ + model = _gen_mnasnet_a1('semnasnet_140', 1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mnasnet_small(pretrained=False, **kwargs): + """ MNASNet Small, depth multiplier of 1.0. 
""" + model = _gen_mnasnet_small('mnasnet_small', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv2_100(pretrained=False, **kwargs): + """ MobileNet V2 """ + model = _gen_mobilenet_v2('mobilenetv2_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def fbnetc_100(pretrained=False, **kwargs): + """ FBNet-C """ + if pretrained: + # pretrained model trained with non-default BN epsilon + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + model = _gen_fbnetc('fbnetc_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def spnasnet_100(pretrained=False, **kwargs): + """ Single-Path NAS Pixel1""" + model = _gen_spnasnet('spnasnet_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b0(pretrained=False, **kwargs): + """ EfficientNet-B0 """ + # NOTE for train, drop_rate should be 0.2, drop_connect_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b1(pretrained=False, **kwargs): + """ EfficientNet-B1 """ + # NOTE for train, drop_rate should be 0.2, drop_connect_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b2(pretrained=False, **kwargs): + """ EfficientNet-B2 """ + # NOTE for train, drop_rate should be 0.3, drop_connect_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b3(pretrained=False, **kwargs): + """ EfficientNet-B3 """ + # NOTE for train, drop_rate should be 0.3, drop_connect_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b4(pretrained=False, **kwargs): + """ EfficientNet-B4 """ + # NOTE for train, drop_rate should be 0.4, drop_connect_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b5(pretrained=False, **kwargs): + """ EfficientNet-B5 """ + # NOTE for train, drop_rate should be 0.4, drop_connect_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b6(pretrained=False, **kwargs): + """ EfficientNet-B6 """ + # NOTE for train, drop_rate should be 0.5, drop_connect_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b7(pretrained=False, **kwargs): + """ EfficientNet-B7 """ + # NOTE for train, drop_rate should be 0.5, drop_connect_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_es(pretrained=False, **kwargs): + """ EfficientNet-Edge Small. 
""" + model = _gen_efficientnet_edge( + 'efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_em(pretrained=False, **kwargs): + """ EfficientNet-Edge-Medium. """ + model = _gen_efficientnet_edge( + 'efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_el(pretrained=False, **kwargs): + """ EfficientNet-Edge-Large. """ + model = _gen_efficientnet_edge( + 'efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_cc_b0_4e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 8 Experts """ + # NOTE for train, drop_rate should be 0.2, drop_connect_rate should be 0.2 + model = _gen_efficientnet_condconv( + 'efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_cc_b0_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 8 Experts """ + # NOTE for train, drop_rate should be 0.2, drop_connect_rate should be 0.2 + model = _gen_efficientnet_condconv( + 'efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + +@register_model +def efficientnet_cc_b1_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B1 w/ 8 Experts """ + # NOTE for train, drop_rate should be 0.2, drop_connect_rate should be 0.2 + model = _gen_efficientnet_condconv( + 'efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b0(pretrained=False, **kwargs): + """ EfficientNet-B0. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b1(pretrained=False, **kwargs): + """ EfficientNet-B1. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b2(pretrained=False, **kwargs): + """ EfficientNet-B2. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b3(pretrained=False, **kwargs): + """ EfficientNet-B3. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b4(pretrained=False, **kwargs): + """ EfficientNet-B4. 
Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b5(pretrained=False, **kwargs): + """ EfficientNet-B5. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b6(pretrained=False, **kwargs): + """ EfficientNet-B6. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b7(pretrained=False, **kwargs): + """ EfficientNet-B7. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b0_ap(pretrained=False, **kwargs): + """ EfficientNet-B0. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b0_ap', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b1_ap(pretrained=False, **kwargs): + """ EfficientNet-B1. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b1_ap', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b2_ap(pretrained=False, **kwargs): + """ EfficientNet-B2. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b2_ap', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b3_ap(pretrained=False, **kwargs): + """ EfficientNet-B3. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b3_ap', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b4_ap(pretrained=False, **kwargs): + """ EfficientNet-B4. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b4_ap', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b5_ap(pretrained=False, **kwargs): + """ EfficientNet-B5. 
Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b5_ap', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b6_ap(pretrained=False, **kwargs): + """ EfficientNet-B6. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b6_ap', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b7_ap(pretrained=False, **kwargs): + """ EfficientNet-B7. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b7_ap', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b8_ap(pretrained=False, **kwargs): + """ EfficientNet-B7. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b8_ap', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs) + return model + + + +@register_model +def tf_efficientnet_es(pretrained=False, **kwargs): + """ EfficientNet-Edge Small. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_edge( + 'tf_efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_em(pretrained=False, **kwargs): + """ EfficientNet-Edge-Medium. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_edge( + 'tf_efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_el(pretrained=False, **kwargs): + """ EfficientNet-Edge-Large. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_edge( + 'tf_efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_cc_b0_4e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 4 Experts. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.2, drop_connect_rate should be 0.2 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_condconv( + 'tf_efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 8 Experts. 
Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.2, drop_connect_rate should be 0.2 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_condconv( + 'tf_efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + +@register_model +def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B1 w/ 8 Experts. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.2, drop_connect_rate should be 0.2 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_condconv( + 'tf_efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +@register_model +def mixnet_s(pretrained=False, **kwargs): + """Creates a MixNet Small model. + """ + model = _gen_mixnet_s( + 'mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mixnet_m(pretrained=False, **kwargs): + """Creates a MixNet Medium model. + """ + model = _gen_mixnet_m( + 'mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mixnet_l(pretrained=False, **kwargs): + """Creates a MixNet Large model. + """ + model = _gen_mixnet_m( + 'mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mixnet_xl(pretrained=False, **kwargs): + """Creates a MixNet Extra-Large model. + Not a paper spec, experimental def by RW w/ depth scaling. + """ + model = _gen_mixnet_m( + 'mixnet_xl', channel_multiplier=1.6, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mixnet_xxl(pretrained=False, **kwargs): + """Creates a MixNet Double Extra Large model. + Not a paper spec, experimental def by RW w/ depth scaling. + """ + model = _gen_mixnet_m( + 'mixnet_xxl', channel_multiplier=2.4, depth_multiplier=1.3, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mixnet_s(pretrained=False, **kwargs): + """Creates a MixNet Small model. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mixnet_s( + 'tf_mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mixnet_m(pretrained=False, **kwargs): + """Creates a MixNet Medium model. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mixnet_m( + 'tf_mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mixnet_l(pretrained=False, **kwargs): + """Creates a MixNet Large model. 
Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mixnet_m( + 'tf_mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs) + return model + diff --git a/timm/models/efficientnet_blocks.py b/timm/models/efficientnet_blocks.py new file mode 100644 index 00000000..13ab051a --- /dev/null +++ b/timm/models/efficientnet_blocks.py @@ -0,0 +1,404 @@ + +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +from .activations import sigmoid +from .conv2d_layers import * + + +# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per +# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay) +# NOTE: momentum varies btw .99 and .9997 depending on source +# .99 in official TF TPU impl +# .9997 (/w .999 in search space) for paper +BN_MOMENTUM_TF_DEFAULT = 1 - 0.99 +BN_EPS_TF_DEFAULT = 1e-3 +_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT) + + +def get_bn_args_tf(): + return _BN_ARGS_TF.copy() + + +def resolve_bn_args(kwargs): + bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {} + bn_momentum = kwargs.pop('bn_momentum', None) + if bn_momentum is not None: + bn_args['momentum'] = bn_momentum + bn_eps = kwargs.pop('bn_eps', None) + if bn_eps is not None: + bn_args['eps'] = bn_eps + return bn_args + + +_SE_ARGS_DEFAULT = dict( + gate_fn=sigmoid, + act_layer=None, + reduce_mid=False, + divisor=1) + + +def resolve_se_args(kwargs, in_chs, act_layer=None): + se_kwargs = kwargs.copy() if kwargs is not None else {} + # fill in args that aren't specified with the defaults + for k, v in _SE_ARGS_DEFAULT.items(): + se_kwargs.setdefault(k, v) + # some models, like MobilNetV3, calculate SE reduction chs from the containing block's mid_ch instead of in_ch + if not se_kwargs.pop('reduce_mid'): + se_kwargs['reduced_base_chs'] = in_chs + # act_layer override, if it remains None, the containing block's act_layer will be used + if se_kwargs['act_layer'] is None: + assert act_layer is not None + se_kwargs['act_layer'] = act_layer + return se_kwargs + + +def make_divisible(v, divisor=8, min_value=None): + min_value = min_value or divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. 
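+ # Illustrative values (not from the original): with the default divisor of 8,
+ # make_divisible(22) -> 24 and make_divisible(17) -> 16, i.e. v is rounded to the
+ # nearest multiple of 8. The guard below only fires when that rounding loses more
+ # than 10% of v, e.g. make_divisible(23, divisor=16) first rounds down to 16
+ # (16 < 0.9 * 23) and is bumped back up to 32.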
+ if new_v < 0.9 * v: + new_v += divisor + return new_v + + +def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None): + """Round number of filters based on depth multiplier.""" + if not multiplier: + return channels + channels *= multiplier + return make_divisible(channels, divisor, channel_min) + + +def drop_connect(inputs, training=False, drop_connect_rate=0.): + """Apply drop connect.""" + if not training: + return inputs + + keep_prob = 1 - drop_connect_rate + random_tensor = keep_prob + torch.rand( + (inputs.size()[0], 1, 1, 1), dtype=inputs.dtype, device=inputs.device) + random_tensor.floor_() # binarize + output = inputs.div(keep_prob) * random_tensor + return output + + +class ChannelShuffle(nn.Module): + # FIXME haven't used yet + def __init__(self, groups): + super(ChannelShuffle, self).__init__() + self.groups = groups + + def forward(self, x): + """Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]""" + N, C, H, W = x.size() + g = self.groups + assert C % g == 0, "Incompatible group size {} for input channel {}".format( + g, C + ) + return ( + x.view(N, g, int(C / g), H, W) + .permute(0, 2, 1, 3, 4) + .contiguous() + .view(N, C, H, W) + ) + + +class SqueezeExcite(nn.Module): + def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None, + act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1, **_): + super(SqueezeExcite, self).__init__() + self.gate_fn = gate_fn + reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor) + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True) + self.act1 = act_layer(inplace=True) + self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True) + + def forward(self, x): + x_se = self.avg_pool(x) + x_se = self.conv_reduce(x_se) + x_se = self.act1(x_se) + x_se = self.conv_expand(x_se) + x = x * self.gate_fn(x_se) + return x + + +class ConvBnAct(nn.Module): + def __init__(self, in_chs, out_chs, kernel_size, + stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, norm_kwargs=None): + super(ConvBnAct, self).__init__() + norm_kwargs = norm_kwargs or {} + self.conv = select_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type) + self.bn1 = norm_layer(out_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + def feature_module(self, location): + return 'act1' + + def feature_channels(self, location): + return self.conv.out_channels + + def forward(self, x): + x = self.conv(x) + x = self.bn1(x) + x = self.act1(x) + return x + + +class DepthwiseSeparableConv(nn.Module): + """ DepthwiseSeparable block + Used for DS convs in MobileNet-V1 and in the place of IR blocks that have no expansion + (factor of 1.0). This is an alternative to having a IR with an optional first pw conv. + """ + def __init__(self, in_chs, out_chs, dw_kernel_size=3, + stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, + pw_kernel_size=1, pw_act=False, se_ratio=0., se_kwargs=None, + norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.): + super(DepthwiseSeparableConv, self).__init__() + norm_kwargs = norm_kwargs or {} + self.has_se = se_ratio is not None and se_ratio > 0. 
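+ # Layout built below: depthwise conv -> BN -> act -> optional SE -> pointwise conv
+ # -> BN -> optional act. This is the block that arch strings such as
+ # 'ds_r1_k3_s1_e1_c24' (see the MixNet arch_def earlier in this diff) expand into;
+ # the shortcut is only kept when stride == 1 and in_chs == out_chs.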
+ self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip + self.has_pw_act = pw_act # activation after point-wise conv + self.drop_connect_rate = drop_connect_rate + + self.conv_dw = select_conv2d( + in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, depthwise=True) + self.bn1 = norm_layer(in_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + # Squeeze-and-excitation + if self.has_se: + se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) + self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs) + + self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type) + self.bn2 = norm_layer(out_chs, **norm_kwargs) + self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity() + + def feature_module(self, location): + # no expansion in this block, pre pw only feature extraction point + return 'conv_pw' + + def feature_channels(self, location): + return self.conv_pw.in_channels + + def forward(self, x): + residual = x + + x = self.conv_dw(x) + x = self.bn1(x) + x = self.act1(x) + + if self.has_se: + x = self.se(x) + + x = self.conv_pw(x) + x = self.bn2(x) + x = self.act2(x) + + if self.has_residual: + if self.drop_connect_rate > 0.: + x = drop_connect(x, self.training, self.drop_connect_rate) + x += residual + return x + + +class InvertedResidual(nn.Module): + """ Inverted residual block w/ optional SE and CondConv routing""" + + def __init__(self, in_chs, out_chs, dw_kernel_size=3, + stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, + exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, + se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, + conv_kwargs=None, drop_connect_rate=0.): + super(InvertedResidual, self).__init__() + norm_kwargs = norm_kwargs or {} + conv_kwargs = conv_kwargs or {} + mid_chs = make_divisible(in_chs * exp_ratio) + self.has_se = se_ratio is not None and se_ratio > 0. 
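+ # MBConv layout built below: 1x1 expansion to mid_chs = make_divisible(in_chs * exp_ratio),
+ # depthwise conv, optional SE (computed on mid_chs), then a 1x1 linear projection.
+ # Illustrative numbers: in_chs=32 with exp_ratio=6 expands to 192 channels. The
+ # residual is only added when stride == 1 and in_chs == out_chs.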
+ self.has_residual = (in_chs == out_chs and stride == 1) and not noskip + self.drop_connect_rate = drop_connect_rate + + # Point-wise expansion + self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs) + self.bn1 = norm_layer(mid_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + # Depth-wise convolution + self.conv_dw = select_conv2d( + mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation, + padding=pad_type, depthwise=True, **conv_kwargs) + self.bn2 = norm_layer(mid_chs, **norm_kwargs) + self.act2 = act_layer(inplace=True) + + # Squeeze-and-excitation + if self.has_se: + se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) + self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs) + + # Point-wise linear projection + self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs) + self.bn3 = norm_layer(out_chs, **norm_kwargs) + + def feature_module(self, location): + if location == 'post_exp': + return 'act1' + return 'conv_pwl' + + def feature_channels(self, location): + if location == 'post_exp': + return self.conv_pw.out_channels + # location == 'pre_pw' + return self.conv_pwl.in_channels + + def forward(self, x): + residual = x + + # Point-wise expansion + x = self.conv_pw(x) + x = self.bn1(x) + x = self.act1(x) + + # Depth-wise convolution + x = self.conv_dw(x) + x = self.bn2(x) + x = self.act2(x) + + # Squeeze-and-excitation + if self.has_se: + x = self.se(x) + + # Point-wise linear projection + x = self.conv_pwl(x) + x = self.bn3(x) + + if self.has_residual: + if self.drop_connect_rate > 0.: + x = drop_connect(x, self.training, self.drop_connect_rate) + x += residual + + return x + + +class CondConvResidual(InvertedResidual): + """ Inverted residual block w/ CondConv routing""" + + def __init__(self, in_chs, out_chs, dw_kernel_size=3, + stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, + exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, + se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, + num_experts=0, drop_connect_rate=0.): + + self.num_experts = num_experts + conv_kwargs = dict(num_experts=self.num_experts) + + super(CondConvResidual, self).__init__( + in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, pad_type=pad_type, + act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size, + pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_kwargs=se_kwargs, + norm_layer=norm_layer, norm_kwargs=norm_kwargs, conv_kwargs=conv_kwargs, + drop_connect_rate=drop_connect_rate) + + self.routing_fn = nn.Linear(in_chs, self.num_experts) + + def forward(self, x): + residual = x + + # CondConv routing + pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1) + routing_weights = torch.sigmoid(self.routing_fn(pooled_inputs)) + + # Point-wise expansion + x = self.conv_pw(x, routing_weights) + x = self.bn1(x) + x = self.act1(x) + + # Depth-wise convolution + x = self.conv_dw(x, routing_weights) + x = self.bn2(x) + x = self.act2(x) + + # Squeeze-and-excitation + if self.has_se: + x = self.se(x) + + # Point-wise linear projection + x = self.conv_pwl(x, routing_weights) + x = self.bn3(x) + + if self.has_residual: + if self.drop_connect_rate > 0.: + x = drop_connect(x, self.training, self.drop_connect_rate) + x += residual + return x + + +class EdgeResidual(nn.Module): + """ Residual block with expansion convolution followed by pointwise-linear w/ stride""" + + def __init__(self, in_chs, 
out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0, + stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1, + se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, + drop_connect_rate=0.): + super(EdgeResidual, self).__init__() + norm_kwargs = norm_kwargs or {} + if fake_in_chs > 0: + mid_chs = make_divisible(fake_in_chs * exp_ratio) + else: + mid_chs = make_divisible(in_chs * exp_ratio) + self.has_se = se_ratio is not None and se_ratio > 0. + self.has_residual = (in_chs == out_chs and stride == 1) and not noskip + self.drop_connect_rate = drop_connect_rate + + # Expansion convolution + self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type) + self.bn1 = norm_layer(mid_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + # Squeeze-and-excitation + if self.has_se: + se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) + self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs) + + # Point-wise linear projection + self.conv_pwl = select_conv2d( + mid_chs, out_chs, pw_kernel_size, stride=stride, dilation=dilation, padding=pad_type) + self.bn2 = norm_layer(out_chs, **norm_kwargs) + + def feature_module(self, location): + if location == 'post_exp': + return 'act1' + return 'conv_pwl' + + def feature_channels(self, location): + if location == 'post_exp': + return self.conv_exp.out_channels + # location == 'pre_pw' + return self.conv_pwl.in_channels + + def forward(self, x): + residual = x + + # Expansion convolution + x = self.conv_exp(x) + x = self.bn1(x) + x = self.act1(x) + + # Squeeze-and-excitation + if self.has_se: + x = self.se(x) + + # Point-wise linear projection + x = self.conv_pwl(x) + x = self.bn2(x) + + if self.has_residual: + if self.drop_connect_rate > 0.: + x = drop_connect(x, self.training, self.drop_connect_rate) + x += residual + + return x diff --git a/timm/models/efficientnet_builder.py b/timm/models/efficientnet_builder.py new file mode 100644 index 00000000..c2b3a801 --- /dev/null +++ b/timm/models/efficientnet_builder.py @@ -0,0 +1,402 @@ +import logging +import math +import re +from collections.__init__ import OrderedDict +from copy import deepcopy + +import torch.nn as nn +from .activations import sigmoid, HardSwish, Swish +from .efficientnet_blocks import * + + +def _parse_ksize(ss): + if ss.isdigit(): + return int(ss) + else: + return [int(k) for k in ss.split('.')] + + +def _decode_block_str(block_str): + """ Decode block definition string + + Gets a list of block arg (dicts) through a string notation of arguments. + E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip + + All args can exist in any order with the exception of the leading string which + is assumed to indicate the block type. + + leading string - block type ( + ir = InvertedResidual, ds = DepthwiseSep, dsa = DeptwhiseSep with pw act, cn = ConvBnAct) + r - number of repeat blocks, + k - kernel size, + s - strides (1-9), + e - expansion ratio, + c - output channels, + se - squeeze/excitation ratio + n - activation fn ('re', 'r6', 'hs', or 'sw') + Args: + block_str: a string representation of block arguments. 
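+
+ Example (illustrative): 'ir_r1_k3_s1_e6_c120_se0.5_nsw', as used in the MixNet
+ definition earlier in this diff, decodes to num_repeat=1 and block args roughly
+ equal to dict(block_type='ir', dw_kernel_size=3, out_chs=120, exp_ratio=6.0,
+ se_ratio=0.5, stride=1, act_layer=Swish, noskip=False).
+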
+ Returns: + A list of block args (dicts) + Raises: + ValueError: if the string def not properly specified (TODO) + """ + assert isinstance(block_str, str) + ops = block_str.split('_') + block_type = ops[0] # take the block type off the front + ops = ops[1:] + options = {} + noskip = False + for op in ops: + # string options being checked on individual basis, combine if they grow + if op == 'noskip': + noskip = True + elif op.startswith('n'): + # activation fn + key = op[0] + v = op[1:] + if v == 're': + value = nn.ReLU + elif v == 'r6': + value = nn.ReLU6 + elif v == 'hs': + value = HardSwish + elif v == 'sw': + value = Swish + else: + continue + options[key] = value + else: + # all numeric options + splits = re.split(r'(\d.*)', op) + if len(splits) >= 2: + key, value = splits[:2] + options[key] = value + + # if act_layer is None, the model default (passed to model init) will be used + act_layer = options['n'] if 'n' in options else None + exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1 + pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1 + fake_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def + + num_repeat = int(options['r']) + # each type of block has different valid arguments, fill accordingly + if block_type == 'ir': + block_args = dict( + block_type=block_type, + dw_kernel_size=_parse_ksize(options['k']), + exp_kernel_size=exp_kernel_size, + pw_kernel_size=pw_kernel_size, + out_chs=int(options['c']), + exp_ratio=float(options['e']), + se_ratio=float(options['se']) if 'se' in options else None, + stride=int(options['s']), + act_layer=act_layer, + noskip=noskip, + ) + if 'cc' in options: + block_args['num_experts'] = int(options['cc']) + elif block_type == 'ds' or block_type == 'dsa': + block_args = dict( + block_type=block_type, + dw_kernel_size=_parse_ksize(options['k']), + pw_kernel_size=pw_kernel_size, + out_chs=int(options['c']), + se_ratio=float(options['se']) if 'se' in options else None, + stride=int(options['s']), + act_layer=act_layer, + pw_act=block_type == 'dsa', + noskip=block_type == 'dsa' or noskip, + ) + elif block_type == 'er': + block_args = dict( + block_type=block_type, + exp_kernel_size=_parse_ksize(options['k']), + pw_kernel_size=pw_kernel_size, + out_chs=int(options['c']), + exp_ratio=float(options['e']), + fake_in_chs=fake_in_chs, + se_ratio=float(options['se']) if 'se' in options else None, + stride=int(options['s']), + act_layer=act_layer, + noskip=noskip, + ) + elif block_type == 'cn': + block_args = dict( + block_type=block_type, + kernel_size=int(options['k']), + out_chs=int(options['c']), + stride=int(options['s']), + act_layer=act_layer, + ) + else: + assert False, 'Unknown block type (%s)' % block_type + + return block_args, num_repeat + + +def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'): + """ Per-stage depth scaling + Scales the block repeats in each stage. This depth scaling impl maintains + compatibility with the EfficientNet scaling method, while allowing sensible + scaling for other models that may have multiple block arg definitions in each stage. + """ + + # We scale the total repeat count for each stage, there may be multiple + # block arg defs per stage so we need to sum. + num_repeat = sum(repeats) + if depth_trunc == 'round': + # Truncating to int by rounding allows stages with few repeats to remain + # proportionally smaller for longer. 
This is a good choice when stage definitions + # include single repeat stages that we'd prefer to keep that way as long as possible + num_repeat_scaled = max(1, round(num_repeat * depth_multiplier)) + else: + # The default for EfficientNet truncates repeats to int via 'ceil'. + # Any multiplier > 1.0 will result in an increased depth for every stage. + num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier)) + + # Proportionally distribute repeat count scaling to each block definition in the stage. + # Allocation is done in reverse as it results in the first block being less likely to be scaled. + # The first block makes less sense to repeat in most of the arch definitions. + repeats_scaled = [] + for r in repeats[::-1]: + rs = max(1, round((r / num_repeat * num_repeat_scaled))) + repeats_scaled.append(rs) + num_repeat -= r + num_repeat_scaled -= rs + repeats_scaled = repeats_scaled[::-1] + + # Apply the calculated scaling to each block arg in the stage + sa_scaled = [] + for ba, rep in zip(stack_args, repeats_scaled): + sa_scaled.extend([deepcopy(ba) for _ in range(rep)]) + return sa_scaled + + +def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1): + arch_args = [] + for stack_idx, block_strings in enumerate(arch_def): + assert isinstance(block_strings, list) + stack_args = [] + repeats = [] + for block_str in block_strings: + assert isinstance(block_str, str) + ba, rep = _decode_block_str(block_str) + if ba.get('num_experts', 0) > 0 and experts_multiplier > 1: + ba['num_experts'] *= experts_multiplier + stack_args.append(ba) + repeats.append(rep) + arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc)) + return arch_args + + +class EfficientNetBuilder: + """ Build Trunk Blocks + + This ended up being somewhat of a cross between + https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py + and + https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py + + """ + def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None, + output_stride=32, pad_type='', act_layer=None, se_kwargs=None, + norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0., feature_location='', + verbose=False): + self.channel_multiplier = channel_multiplier + self.channel_divisor = channel_divisor + self.channel_min = channel_min + self.output_stride = output_stride + self.pad_type = pad_type + self.act_layer = act_layer + self.se_kwargs = se_kwargs + self.norm_layer = norm_layer + self.norm_kwargs = norm_kwargs + self.drop_connect_rate = drop_connect_rate + self.feature_location = feature_location + assert feature_location in ('pre_pwl', 'post_exp', '') + self.verbose = verbose + + # state updated during build, consumed by model + self.in_chs = None + self.features = OrderedDict() + + def _round_channels(self, chs): + return round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min) + + def _make_block(self, ba, block_idx, block_count): + drop_connect_rate = self.drop_connect_rate * block_idx / block_count + bt = ba.pop('block_type') + ba['in_chs'] = self.in_chs + ba['out_chs'] = self._round_channels(ba['out_chs']) + if 'fake_in_chs' in ba and ba['fake_in_chs']: + # FIXME this is a hack to work around mismatch in origin impl input filters + ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs']) + ba['norm_layer'] = self.norm_layer + ba['norm_kwargs'] = self.norm_kwargs + 
ba['pad_type'] = self.pad_type + # block act fn overrides the model default + ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer + assert ba['act_layer'] is not None + if bt == 'ir': + ba['drop_connect_rate'] = drop_connect_rate + ba['se_kwargs'] = self.se_kwargs + if self.verbose: + logging.info(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba))) + if ba.get('num_experts', 0) > 0: + block = CondConvResidual(**ba) + else: + block = InvertedResidual(**ba) + elif bt == 'ds' or bt == 'dsa': + ba['drop_connect_rate'] = drop_connect_rate + ba['se_kwargs'] = self.se_kwargs + if self.verbose: + logging.info(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba))) + block = DepthwiseSeparableConv(**ba) + elif bt == 'er': + ba['drop_connect_rate'] = drop_connect_rate + ba['se_kwargs'] = self.se_kwargs + if self.verbose: + logging.info(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba))) + block = EdgeResidual(**ba) + elif bt == 'cn': + if self.verbose: + logging.info(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba))) + block = ConvBnAct(**ba) + else: + assert False, 'Uknkown block type (%s) while building model.' % bt + self.in_chs = ba['out_chs'] # update in_chs for arg of next block + + return block + + def __call__(self, in_chs, model_block_args): + """ Build the blocks + Args: + in_chs: Number of input-channels passed to first block + model_block_args: A list of lists, outer list defines stages, inner + list contains strings defining block configuration(s) + Return: + List of block stacks (each stack wrapped in nn.Sequential) + """ + if self.verbose: + logging.info('Building model trunk with %d stages...' % len(model_block_args)) + self.in_chs = in_chs + total_block_count = sum([len(x) for x in model_block_args]) + total_block_idx = 0 + current_stride = 2 + current_dilation = 1 + feature_idx = 0 + stages = [] + # outer list of block_args defines the stacks ('stages' by some conventions) + for stage_idx, stage_block_args in enumerate(model_block_args): + last_stack = stage_idx == (len(model_block_args) - 1) + if self.verbose: + logging.info('Stack: {}'.format(stage_idx)) + assert isinstance(stage_block_args, list) + + blocks = [] + # each stack (stage) contains a list of block arguments + for block_idx, block_args in enumerate(stage_block_args): + last_block = block_idx == (len(stage_block_args) - 1) + extract_features = '' # No features extracted + if self.verbose: + logging.info(' Block: {}'.format(block_idx)) + + # Sort out stride, dilation, and feature extraction details + assert block_args['stride'] in (1, 2) + if block_idx >= 1: + # only the first block in any stack can have a stride > 1 + block_args['stride'] = 1 + + do_extract = False + if self.feature_location == 'pre_pwl': + if last_block: + next_stage_idx = stage_idx + 1 + if next_stage_idx >= len(model_block_args): + do_extract = True + else: + do_extract = model_block_args[next_stage_idx][0]['stride'] > 1 + elif self.feature_location == 'post_exp': + if block_args['stride'] > 1 or (last_stack and last_block) : + do_extract = True + if do_extract: + extract_features = self.feature_location + + next_dilation = current_dilation + if block_args['stride'] > 1: + next_output_stride = current_stride * block_args['stride'] + if next_output_stride > self.output_stride: + next_dilation = current_dilation * block_args['stride'] + block_args['stride'] = 1 + if self.verbose: + logging.info(' Converting stride to dilation to maintain output_stride=={}'.format( + self.output_stride)) + else: 
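+ # Still within the requested output_stride budget, so just accumulate the stride.
+ # (When the budget would be exceeded, the branch above converts the stride to
+ # dilation instead, e.g. with output_stride=16 a further stride-2 block runs at
+ # stride 1 with doubled dilation, preserving spatial resolution.)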
+ current_stride = next_output_stride + block_args['dilation'] = current_dilation + if next_dilation != current_dilation: + current_dilation = next_dilation + + # create the block + block = self._make_block(block_args, total_block_idx, total_block_count) + blocks.append(block) + + # stash feature module name and channel info for model feature extraction + if extract_features: + feature_module = block.feature_module(extract_features) + if feature_module: + feature_module = 'blocks.{}.{}.'.format(stage_idx, block_idx) + feature_module + feature_channels = block.feature_channels(extract_features) + self.features[feature_idx] = dict( + name=feature_module, + num_chs=feature_channels + ) + feature_idx += 1 + + total_block_idx += 1 # incr global block idx (across all stacks) + stages.append(nn.Sequential(*blocks)) + return stages + + +def efficientnet_init_goog(m, n=''): + # weight init as per Tensorflow Official impl + # https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py + if isinstance(m, CondConv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + init_weight_fn = get_condconv_initializer( + lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape) + init_weight_fn(m.weight) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + fan_out = m.weight.size(0) # fan-out + fan_in = 0 + if 'routing_fn' in n: + fan_in = m.weight.size(1) + init_range = 1.0 / math.sqrt(fan_in + fan_out) + m.weight.data.uniform_(-init_range, init_range) + m.bias.data.zero_() + + +def efficientnet_init_default(m, n=''): + if isinstance(m, CondConv2d): + init_fn = get_condconv_initializer(partial( + nn.init.kaiming_normal_, mode='fan_out', nonlinearity='relu'), m.num_experts, m.weight_shape) + init_fn(m.weight) + elif isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear') + + diff --git a/timm/models/factory.py b/timm/models/factory.py index 3c051e75..7b9f7a07 100644 --- a/timm/models/factory.py +++ b/timm/models/factory.py @@ -25,8 +25,8 @@ def create_model( """ margs = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans) - # Only gen_efficientnet models have support for batchnorm params or drop_connect_rate passed as args - is_efficientnet = is_model_in_modules(model_name, ['gen_efficientnet']) + # Only EfficientNet and MobileNetV3 models have support for batchnorm params or drop_connect_rate passed as args + is_efficientnet = is_model_in_modules(model_name, ['efficientnet', 'mobilenetv3']) if not is_efficientnet: kwargs.pop('bn_tf', None) kwargs.pop('bn_momentum', None) diff --git a/timm/models/feature_hooks.py b/timm/models/feature_hooks.py new file mode 100644 index 00000000..8ffcda86 --- /dev/null +++ b/timm/models/feature_hooks.py @@ -0,0 +1,31 @@ +from collections import defaultdict, OrderedDict +from functools import partial + + +class FeatureHooks: + + def __init__(self, hooks, named_modules): + # setup feature hooks + modules = {k: v for k, v in named_modules} + for h in 
hooks: + hook_name = h['name'] + m = modules[hook_name] + hook_fn = partial(self._collect_output_hook, hook_name) + if h['type'] == 'forward_pre': + m.register_forward_pre_hook(hook_fn) + elif h['type'] == 'forward': + m.register_forward_hook(hook_fn) + else: + assert False, "Unsupported hook type" + self._feature_outputs = defaultdict(OrderedDict) + + def _collect_output_hook(self, name, *args): + x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre + if isinstance(x, tuple): + x = x[0] # unwrap input tuple + self._feature_outputs[x.device][name] = x + + def get_output(self, device): + output = tuple(self._feature_outputs[device].values())[::-1] + self._feature_outputs[device] = OrderedDict() # clear after reading + return output diff --git a/timm/models/gen_efficientnet.py b/timm/models/gen_efficientnet.py deleted file mode 100644 index a7191025..00000000 --- a/timm/models/gen_efficientnet.py +++ /dev/null @@ -1,2027 +0,0 @@ -""" Generic EfficientNets - -A generic class with building blocks to support a variety of models with efficient architectures: -* EfficientNet (B0-B7) -* MixNet (Small, Medium, and Large) -* MnasNet B1, A1 (SE), Small -* MobileNet V1, V2, and V3 -* FBNet-C (TODO A & B) -* ChamNet (TODO still guessing at architecture definition) -* Single-Path NAS Pixel1 -* And likely more... - -TODO not all combinations and variations have been tested. Currently working on training hyper-params... - -Hacked together by Ross Wightman -""" - -import math -import re -import logging -from copy import deepcopy - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from .registry import register_model -from .helpers import load_pretrained -from .adaptive_avgmax_pool import SelectAdaptivePool2d -from .conv2d_helpers import select_conv2d -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD - - -__all__ = ['GenEfficientNet'] - - -def _cfg(url='', **kwargs): - return { - 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), - 'crop_pct': 0.875, 'interpolation': 'bicubic', - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'conv_stem', 'classifier': 'classifier', - **kwargs - } - - -default_cfgs = { - 'mnasnet_050': _cfg(url=''), - 'mnasnet_075': _cfg(url=''), - 'mnasnet_100': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth'), - 'mnasnet_140': _cfg(url=''), - 'semnasnet_050': _cfg(url=''), - 'semnasnet_075': _cfg(url=''), - 'semnasnet_100': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth'), - 'semnasnet_140': _cfg(url=''), - 'mnasnet_small': _cfg(url=''), - 'mobilenetv1_100': _cfg(url=''), - 'mobilenetv2_100': _cfg(url=''), - 'mobilenetv3_050': _cfg(url=''), - 'mobilenetv3_075': _cfg(url=''), - 'mobilenetv3_100': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth'), - 'chamnetv1_100': _cfg(url=''), - 'chamnetv2_100': _cfg(url=''), - 'fbnetc_100': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth', - interpolation='bilinear'), - 'spnasnet_100': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/spnasnet_100-048bc3f4.pth', - interpolation='bilinear'), - 'efficientnet_b0': _cfg( - 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0-d6904d92.pth'), - 'efficientnet_b1': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth', - input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), - 'efficientnet_b2': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2-cf78dc4d.pth', - input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), - 'efficientnet_b3': _cfg( - url='', input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), - 'efficientnet_b4': _cfg( - url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), - 'efficientnet_b5': _cfg( - url='', input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), - 'efficientnet_b6': _cfg( - url='', input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942), - 'efficientnet_b7': _cfg( - url='', input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949), - 'efficientnet_es': _cfg( - url=''), - 'efficientnet_em': _cfg( - url='', - input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), - 'efficientnet_el': _cfg( - url='', - input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), - 'tf_efficientnet_b0': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth', - input_size=(3, 224, 224)), - 'tf_efficientnet_b1': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth', - input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), - 'tf_efficientnet_b2': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97.pth', - input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), - 'tf_efficientnet_b3': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth', - input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), - 'tf_efficientnet_b4': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth', - input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), - 'tf_efficientnet_b5': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth', - input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), - 'tf_efficientnet_b6': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth', - input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942), - 'tf_efficientnet_b7': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth', - input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949), - 'tf_efficientnet_es': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 224, 224), ), - 'tf_efficientnet_em': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_em-e78cfe58.pth', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), - 'tf_efficientnet_el': _cfg( - 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_el-5143854e.pth', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), - 'mixnet_s': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth'), - 'mixnet_m': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth'), - 'mixnet_l': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_l-5a9a2ed8.pth'), - 'mixnet_xl': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_xl-ac5fbe8d.pth'), - 'mixnet_xxl': _cfg(), - 'tf_mixnet_s': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth'), - 'tf_mixnet_m': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_m-0f4d8805.pth'), - 'tf_mixnet_l': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_l-6c92e0c8.pth'), -} - - -_DEBUG = False - -# Default args for PyTorch BN impl -_BN_MOMENTUM_PT_DEFAULT = 0.1 -_BN_EPS_PT_DEFAULT = 1e-5 -_BN_ARGS_PT = dict(momentum=_BN_MOMENTUM_PT_DEFAULT, eps=_BN_EPS_PT_DEFAULT) - -# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per -# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay) -# NOTE: momentum varies btw .99 and .9997 depending on source -# .99 in official TF TPU impl -# .9997 (/w .999 in search space) for paper -_BN_MOMENTUM_TF_DEFAULT = 1 - 0.99 -_BN_EPS_TF_DEFAULT = 1e-3 -_BN_ARGS_TF = dict(momentum=_BN_MOMENTUM_TF_DEFAULT, eps=_BN_EPS_TF_DEFAULT) - - -def _resolve_bn_args(kwargs): - bn_args = _BN_ARGS_TF.copy() if kwargs.pop('bn_tf', False) else _BN_ARGS_PT.copy() - bn_momentum = kwargs.pop('bn_momentum', None) - if bn_momentum is not None: - bn_args['momentum'] = bn_momentum - bn_eps = kwargs.pop('bn_eps', None) - if bn_eps is not None: - bn_args['eps'] = bn_eps - return bn_args - - -def _round_channels(channels, multiplier=1.0, divisor=8, channel_min=None): - """Round number of filters based on depth multiplier.""" - if not multiplier: - return channels - - channels *= multiplier - channel_min = channel_min or divisor - new_channels = max( - int(channels + divisor / 2) // divisor * divisor, - channel_min) - # Make sure that round down does not go down by more than 10%. - if new_channels < 0.9 * channels: - new_channels += divisor - return new_channels - - -def _parse_ksize(ss): - if ss.isdigit(): - return int(ss) - else: - return [int(k) for k in ss.split('.')] - - -def _decode_block_str(block_str, depth_multiplier=1.0): - """ Decode block definition string - - Gets a list of block arg (dicts) through a string notation of arguments. - E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip - - All args can exist in any order with the exception of the leading string which - is assumed to indicate the block type. - - leading string - block type ( - ir = InvertedResidual, ds = DepthwiseSep, dsa = DeptwhiseSep with pw act, cn = ConvBnAct) - r - number of repeat blocks, - k - kernel size, - s - strides (1-9), - e - expansion ratio, - c - output channels, - se - squeeze/excitation ratio - n - activation fn ('re', 'r6', 'hs', or 'sw') - Args: - block_str: a string representation of block arguments. 
- Returns: - A list of block args (dicts) - Raises: - ValueError: if the string def not properly specified (TODO) - """ - assert isinstance(block_str, str) - ops = block_str.split('_') - block_type = ops[0] # take the block type off the front - ops = ops[1:] - options = {} - noskip = False - for op in ops: - # string options being checked on individual basis, combine if they grow - if op == 'noskip': - noskip = True - elif op.startswith('n'): - # activation fn - key = op[0] - v = op[1:] - if v == 're': - value = F.relu - elif v == 'r6': - value = F.relu6 - elif v == 'hs': - value = hard_swish - elif v == 'sw': - value = swish - else: - continue - options[key] = value - else: - # all numeric options - splits = re.split(r'(\d.*)', op) - if len(splits) >= 2: - key, value = splits[:2] - options[key] = value - - # if act_fn is None, the model default (passed to model init) will be used - act_fn = options['n'] if 'n' in options else None - exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1 - pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1 - fake_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def - - num_repeat = int(options['r']) - # each type of block has different valid arguments, fill accordingly - if block_type == 'ir': - block_args = dict( - block_type=block_type, - dw_kernel_size=_parse_ksize(options['k']), - exp_kernel_size=exp_kernel_size, - pw_kernel_size=pw_kernel_size, - out_chs=int(options['c']), - exp_ratio=float(options['e']), - se_ratio=float(options['se']) if 'se' in options else None, - stride=int(options['s']), - act_fn=act_fn, - noskip=noskip, - ) - elif block_type == 'ds' or block_type == 'dsa': - block_args = dict( - block_type=block_type, - dw_kernel_size=_parse_ksize(options['k']), - pw_kernel_size=pw_kernel_size, - out_chs=int(options['c']), - se_ratio=float(options['se']) if 'se' in options else None, - stride=int(options['s']), - act_fn=act_fn, - pw_act=block_type == 'dsa', - noskip=block_type == 'dsa' or noskip, - ) - elif block_type == 'er': - block_args = dict( - block_type=block_type, - exp_kernel_size=_parse_ksize(options['k']), - pw_kernel_size=pw_kernel_size, - out_chs=int(options['c']), - exp_ratio=float(options['e']), - fake_in_chs=fake_in_chs, - se_ratio=float(options['se']) if 'se' in options else None, - stride=int(options['s']), - act_fn=act_fn, - noskip=noskip, - ) - elif block_type == 'cn': - block_args = dict( - block_type=block_type, - kernel_size=int(options['k']), - out_chs=int(options['c']), - stride=int(options['s']), - act_fn=act_fn, - ) - else: - assert False, 'Unknown block type (%s)' % block_type - - return block_args, num_repeat - - -def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'): - """ Per-stage depth scaling - Scales the block repeats in each stage. This depth scaling impl maintains - compatibility with the EfficientNet scaling method, while allowing sensible - scaling for other models that may have multiple block arg definitions in each stage. - """ - - # We scale the total repeat count for each stage, there may be multiple - # block arg defs per stage so we need to sum. - num_repeat = sum(repeats) - if depth_trunc == 'round': - # Truncating to int by rounding allows stages with few repeats to remain - # proportionally smaller for longer. 
This is a good choice when stage definitions - # include single repeat stages that we'd prefer to keep that way as long as possible - num_repeat_scaled = max(1, round(num_repeat * depth_multiplier)) - else: - # The default for EfficientNet truncates repeats to int via 'ceil'. - # Any multiplier > 1.0 will result in an increased depth for every stage. - num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier)) - - # Proportionally distribute repeat count scaling to each block definition in the stage. - # Allocation is done in reverse as it results in the first block being less likely to be scaled. - # The first block makes less sense to repeat in most of the arch definitions. - repeats_scaled = [] - for r in repeats[::-1]: - rs = max(1, round((r / num_repeat * num_repeat_scaled))) - repeats_scaled.append(rs) - num_repeat -= r - num_repeat_scaled -= rs - repeats_scaled = repeats_scaled[::-1] - - # Apply the calculated scaling to each block arg in the stage - sa_scaled = [] - for ba, rep in zip(stack_args, repeats_scaled): - sa_scaled.extend([deepcopy(ba) for _ in range(rep)]) - return sa_scaled - - -def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil'): - arch_args = [] - for stack_idx, block_strings in enumerate(arch_def): - assert isinstance(block_strings, list) - stack_args = [] - repeats = [] - for block_str in block_strings: - assert isinstance(block_str, str) - ba, rep = _decode_block_str(block_str) - stack_args.append(ba) - repeats.append(rep) - arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc)) - return arch_args - - -_USE_SWISH_OPT = True -if _USE_SWISH_OPT: - @torch.jit.script - def swish_jit_fwd(x): - return x.mul(torch.sigmoid(x)) - - - @torch.jit.script - def swish_jit_bwd(x, grad_output): - x_sigmoid = torch.sigmoid(x) - return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid))) - - - class SwishJitAutoFn(torch.autograd.Function): - """ torch.jit.script optimised Swish - Inspired by conversation btw Jeremy Howard & Adam Pazske - https://twitter.com/jeremyphoward/status/1188251041835315200 - """ - - @staticmethod - def forward(ctx, x): - ctx.save_for_backward(x) - return swish_jit_fwd(x) - - @staticmethod - def backward(ctx, grad_output): - x = ctx.saved_tensors[0] - return swish_jit_bwd(x, grad_output) - - - def swish(x, inplace=False): - # inplace ignored - return SwishJitAutoFn.apply(x) -else: - def swish(x, inplace=False): - return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) - - -def sigmoid(x, inplace=False): - return x.sigmoid_() if inplace else x.sigmoid() - - -def hard_swish(x, inplace=False): - if inplace: - return x.mul_(F.relu6(x + 3.) / 6.) - else: - return x * F.relu6(x + 3.) / 6. - - -def hard_sigmoid(x, inplace=False): - if inplace: - return x.add_(3.).clamp_(0., 6.).div_(6.) - else: - return F.relu6(x + 3.) / 6. 
- - -class _BlockBuilder: - """ Build Trunk Blocks - - This ended up being somewhat of a cross between - https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py - and - https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py - - """ - def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None, - pad_type='', act_fn=None, se_gate_fn=sigmoid, se_reduce_mid=False, - bn_args=_BN_ARGS_PT, drop_connect_rate=0., verbose=False): - self.channel_multiplier = channel_multiplier - self.channel_divisor = channel_divisor - self.channel_min = channel_min - self.pad_type = pad_type - self.act_fn = act_fn - self.se_gate_fn = se_gate_fn - self.se_reduce_mid = se_reduce_mid - self.bn_args = bn_args - self.drop_connect_rate = drop_connect_rate - self.verbose = verbose - - # updated during build - self.in_chs = None - self.block_idx = 0 - self.block_count = 0 - - def _round_channels(self, chs): - return _round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min) - - def _make_block(self, ba): - bt = ba.pop('block_type') - ba['in_chs'] = self.in_chs - ba['out_chs'] = self._round_channels(ba['out_chs']) - if 'fake_in_chs' in ba and ba['fake_in_chs']: - # FIXME this is a hack to work around mismatch in origin impl input filters - ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs']) - ba['bn_args'] = self.bn_args - ba['pad_type'] = self.pad_type - # block act fn overrides the model default - ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn - assert ba['act_fn'] is not None - if bt == 'ir': - ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count - ba['se_gate_fn'] = self.se_gate_fn - ba['se_reduce_mid'] = self.se_reduce_mid - if self.verbose: - logging.info(' InvertedResidual {}, Args: {}'.format(self.block_idx, str(ba))) - block = InvertedResidual(**ba) - elif bt == 'ds' or bt == 'dsa': - ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count - if self.verbose: - logging.info(' DepthwiseSeparable {}, Args: {}'.format(self.block_idx, str(ba))) - block = DepthwiseSeparableConv(**ba) - elif bt == 'er': - ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count - ba['se_gate_fn'] = self.se_gate_fn - ba['se_reduce_mid'] = self.se_reduce_mid - if self.verbose: - logging.info(' EdgeResidual {}, Args: {}'.format(self.block_idx, str(ba))) - block = EdgeResidual(**ba) - elif bt == 'cn': - if self.verbose: - logging.info(' ConvBnAct {}, Args: {}'.format(self.block_idx, str(ba))) - block = ConvBnAct(**ba) - else: - assert False, 'Uknkown block type (%s) while building model.' 
% bt - self.in_chs = ba['out_chs'] # update in_chs for arg of next block - - return block - - def _make_stack(self, stack_args): - blocks = [] - # each stack (stage) contains a list of block arguments - for i, ba in enumerate(stack_args): - if self.verbose: - logging.info(' Block: {}'.format(i)) - if i >= 1: - # only the first block in any stack can have a stride > 1 - ba['stride'] = 1 - block = self._make_block(ba) - blocks.append(block) - self.block_idx += 1 # incr global idx (across all stacks) - return nn.Sequential(*blocks) - - def __call__(self, in_chs, block_args): - """ Build the blocks - Args: - in_chs: Number of input-channels passed to first block - block_args: A list of lists, outer list defines stages, inner - list contains strings defining block configuration(s) - Return: - List of block stacks (each stack wrapped in nn.Sequential) - """ - if self.verbose: - logging.info('Building model trunk with %d stages...' % len(block_args)) - self.in_chs = in_chs - self.block_count = sum([len(x) for x in block_args]) - self.block_idx = 0 - blocks = [] - # outer list of block_args defines the stacks ('stages' by some conventions) - for stack_idx, stack in enumerate(block_args): - if self.verbose: - logging.info('Stack: {}'.format(stack_idx)) - assert isinstance(stack, list) - stack = self._make_stack(stack) - blocks.append(stack) - return blocks - - -def _initialize_weight_goog(m): - # weight init as per Tensorflow Official impl - # https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py - if isinstance(m, nn.Conv2d): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels # fan-out - m.weight.data.normal_(0, math.sqrt(2.0 / n)) - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1.0) - m.bias.data.zero_() - elif isinstance(m, nn.Linear): - n = m.weight.size(0) # fan-out - init_range = 1.0 / math.sqrt(n) - m.weight.data.uniform_(-init_range, init_range) - m.bias.data.zero_() - - -def _initialize_weight_default(m): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1.0) - m.bias.data.zero_() - elif isinstance(m, nn.Linear): - nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear') - - -def drop_connect(inputs, training=False, drop_connect_rate=0.): - """Apply drop connect.""" - if not training: - return inputs - - keep_prob = 1 - drop_connect_rate - random_tensor = keep_prob + torch.rand( - (inputs.size()[0], 1, 1, 1), dtype=inputs.dtype, device=inputs.device) - random_tensor.floor_() # binarize - output = inputs.div(keep_prob) * random_tensor - return output - - -class ChannelShuffle(nn.Module): - # FIXME haven't used yet - def __init__(self, groups): - super(ChannelShuffle, self).__init__() - self.groups = groups - - def forward(self, x): - """Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]""" - N, C, H, W = x.size() - g = self.groups - assert C % g == 0, "Incompatible group size {} for input channel {}".format( - g, C - ) - return ( - x.view(N, g, int(C / g), H, W) - .permute(0, 2, 1, 3, 4) - .contiguous() - .view(N, C, H, W) - ) - - -class SqueezeExcite(nn.Module): - def __init__(self, in_chs, reduce_chs=None, act_fn=F.relu, gate_fn=sigmoid): - super(SqueezeExcite, self).__init__() - self.act_fn = act_fn - self.gate_fn = gate_fn - reduced_chs = reduce_chs or in_chs - self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True) - 
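# conv_reduce squeezes the pooled features from in_chs down to reduced_chs; conv_expand
# below maps them back to in_chs before gate_fn scales the input in forward()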
self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True) - - def forward(self, x): - # NOTE adaptiveavgpool can be used here, but seems to cause issues with NVIDIA AMP performance - x_se = x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1) - x_se = self.conv_reduce(x_se) - x_se = self.act_fn(x_se, inplace=True) - x_se = self.conv_expand(x_se) - x = x * self.gate_fn(x_se) - return x - - -class ConvBnAct(nn.Module): - def __init__(self, in_chs, out_chs, kernel_size, - stride=1, pad_type='', act_fn=F.relu, bn_args=_BN_ARGS_PT): - super(ConvBnAct, self).__init__() - assert stride in [1, 2] - self.act_fn = act_fn - self.conv = select_conv2d(in_chs, out_chs, kernel_size, stride=stride, padding=pad_type) - self.bn1 = nn.BatchNorm2d(out_chs, **bn_args) - - def forward(self, x): - x = self.conv(x) - x = self.bn1(x) - x = self.act_fn(x, inplace=True) - return x - - -class EdgeResidual(nn.Module): - """ Residual block with expansion convolution followed by pointwise-linear w/ stride""" - - def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0, - stride=1, pad_type='', act_fn=F.relu, noskip=False, pw_kernel_size=1, - se_ratio=0., se_reduce_mid=False, se_gate_fn=sigmoid, - bn_args=_BN_ARGS_PT, drop_connect_rate=0.): - super(EdgeResidual, self).__init__() - mid_chs = int(fake_in_chs * exp_ratio) if fake_in_chs > 0 else int(in_chs * exp_ratio) - self.has_se = se_ratio is not None and se_ratio > 0. - self.has_residual = (in_chs == out_chs and stride == 1) and not noskip - self.act_fn = act_fn - self.drop_connect_rate = drop_connect_rate - - # Expansion convolution - self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type) - self.bn1 = nn.BatchNorm2d(mid_chs, **bn_args) - - # Squeeze-and-excitation - if self.has_se: - se_base_chs = mid_chs if se_reduce_mid else in_chs - self.se = SqueezeExcite( - mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)), act_fn=act_fn, gate_fn=se_gate_fn) - - # Point-wise linear projection - self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, stride=stride, padding=pad_type) - self.bn2 = nn.BatchNorm2d(out_chs, **bn_args) - - def forward(self, x): - residual = x - - # Expansion convolution - x = self.conv_exp(x) - x = self.bn1(x) - x = self.act_fn(x, inplace=True) - - # Squeeze-and-excitation - if self.has_se: - x = self.se(x) - - # Point-wise linear projection - x = self.conv_pwl(x) - x = self.bn2(x) - - if self.has_residual: - if self.drop_connect_rate > 0.: - x = drop_connect(x, self.training, self.drop_connect_rate) - x += residual - - return x - - -class DepthwiseSeparableConv(nn.Module): - """ DepthwiseSeparable block - Used for DS convs in MobileNet-V1 and in the place of IR blocks with an expansion - factor of 1.0. This is an alternative to having a IR with an optional first pw conv. - """ - def __init__(self, in_chs, out_chs, dw_kernel_size=3, - stride=1, pad_type='', act_fn=F.relu, noskip=False, - pw_kernel_size=1, pw_act=False, - se_ratio=0., se_gate_fn=sigmoid, - bn_args=_BN_ARGS_PT, drop_connect_rate=0.): - super(DepthwiseSeparableConv, self).__init__() - assert stride in [1, 2] - self.has_se = se_ratio is not None and se_ratio > 0. 
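# the identity shortcut is only kept when stride is 1, channel counts match, and skip
# isn't explicitly disabled via 'noskip'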
- self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip - self.has_pw_act = pw_act # activation after point-wise conv - self.act_fn = act_fn - self.drop_connect_rate = drop_connect_rate - - self.conv_dw = select_conv2d( - in_chs, in_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True) - self.bn1 = nn.BatchNorm2d(in_chs, **bn_args) - - # Squeeze-and-excitation - if self.has_se: - self.se = SqueezeExcite( - in_chs, reduce_chs=max(1, int(in_chs * se_ratio)), act_fn=act_fn, gate_fn=se_gate_fn) - - self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type) - self.bn2 = nn.BatchNorm2d(out_chs, **bn_args) - - def forward(self, x): - residual = x - - x = self.conv_dw(x) - x = self.bn1(x) - x = self.act_fn(x, inplace=True) - - if self.has_se: - x = self.se(x) - - x = self.conv_pw(x) - x = self.bn2(x) - if self.has_pw_act: - x = self.act_fn(x, inplace=True) - - if self.has_residual: - if self.drop_connect_rate > 0.: - x = drop_connect(x, self.training, self.drop_connect_rate) - x += residual - return x - - -class InvertedResidual(nn.Module): - """ Inverted residual block w/ optional SE""" - - def __init__(self, in_chs, out_chs, dw_kernel_size=3, - stride=1, pad_type='', act_fn=F.relu, noskip=False, - exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, - se_ratio=0., se_reduce_mid=False, se_gate_fn=sigmoid, - shuffle_type=None, bn_args=_BN_ARGS_PT, drop_connect_rate=0.): - super(InvertedResidual, self).__init__() - mid_chs = int(in_chs * exp_ratio) - self.has_se = se_ratio is not None and se_ratio > 0. - self.has_residual = (in_chs == out_chs and stride == 1) and not noskip - self.act_fn = act_fn - self.drop_connect_rate = drop_connect_rate - - # Point-wise expansion - self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type) - self.bn1 = nn.BatchNorm2d(mid_chs, **bn_args) - - self.shuffle_type = shuffle_type - if shuffle_type is not None and isinstance(exp_kernel_size, list): - self.shuffle = ChannelShuffle(len(exp_kernel_size)) - - # Depth-wise convolution - self.conv_dw = select_conv2d( - mid_chs, mid_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True) - self.bn2 = nn.BatchNorm2d(mid_chs, **bn_args) - - # Squeeze-and-excitation - if self.has_se: - se_base_chs = mid_chs if se_reduce_mid else in_chs - self.se = SqueezeExcite( - mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)), act_fn=act_fn, gate_fn=se_gate_fn) - - # Point-wise linear projection - self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type) - self.bn3 = nn.BatchNorm2d(out_chs, **bn_args) - - def forward(self, x): - residual = x - - # Point-wise expansion - x = self.conv_pw(x) - x = self.bn1(x) - x = self.act_fn(x, inplace=True) - - # FIXME haven't tried this yet - # for channel shuffle when using groups with pointwise convs as per FBNet variants - if self.shuffle_type == "mid": - x = self.shuffle(x) - - # Depth-wise convolution - x = self.conv_dw(x) - x = self.bn2(x) - x = self.act_fn(x, inplace=True) - - # Squeeze-and-excitation - if self.has_se: - x = self.se(x) - - # Point-wise linear projection - x = self.conv_pwl(x) - x = self.bn3(x) - - if self.has_residual: - if self.drop_connect_rate > 0.: - x = drop_connect(x, self.training, self.drop_connect_rate) - x += residual - - # NOTE maskrcnn_benchmark building blocks have an SE module defined here for some variants - - return x - - -class GenEfficientNet(nn.Module): - """ Generic EfficientNet - - An implementation of efficient network architectures, in many 
cases mobile optimized networks: - * MobileNet-V1 - * MobileNet-V2 - * MobileNet-V3 - * MnasNet A1, B1, and small - * FBNet A, B, and C - * ChamNet (arch details are murky) - * Single-Path NAS Pixel1 - * EfficientNet B0-B7 - * MixNet S, M, L - """ - - def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280, - channel_multiplier=1.0, channel_divisor=8, channel_min=None, - pad_type='', act_fn=F.relu, drop_rate=0., drop_connect_rate=0., - se_gate_fn=sigmoid, se_reduce_mid=False, bn_args=_BN_ARGS_PT, - global_pool='avg', head_conv='default', weight_init='goog'): - super(GenEfficientNet, self).__init__() - self.num_classes = num_classes - self.drop_rate = drop_rate - self.act_fn = act_fn - self.num_features = num_features - - stem_size = _round_channels(stem_size, channel_multiplier, channel_divisor, channel_min) - self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) - self.bn1 = nn.BatchNorm2d(stem_size, **bn_args) - in_chs = stem_size - - builder = _BlockBuilder( - channel_multiplier, channel_divisor, channel_min, - pad_type, act_fn, se_gate_fn, se_reduce_mid, - bn_args, drop_connect_rate, verbose=_DEBUG) - self.blocks = nn.Sequential(*builder(in_chs, block_args)) - in_chs = builder.in_chs - - if not head_conv or head_conv == 'none': - self.efficient_head = False - self.conv_head = None - assert in_chs == self.num_features - else: - self.efficient_head = head_conv == 'efficient' - self.conv_head = select_conv2d(in_chs, self.num_features, 1, padding=pad_type) - self.bn2 = None if self.efficient_head else nn.BatchNorm2d(self.num_features, **bn_args) - - self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), self.num_classes) - - for m in self.modules(): - if weight_init == 'goog': - _initialize_weight_goog(m) - else: - _initialize_weight_default(m) - - def get_classifier(self): - return self.classifier - - def reset_classifier(self, num_classes, global_pool='avg'): - self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - self.num_classes = num_classes - del self.classifier - if num_classes: - self.classifier = nn.Linear( - self.num_features * self.global_pool.feat_mult(), num_classes) - else: - self.classifier = None - - def forward_features(self, x, pool=True): - x = self.conv_stem(x) - x = self.bn1(x) - x = self.act_fn(x, inplace=True) - x = self.blocks(x) - if self.efficient_head: - # efficient head, currently only mobilenet-v3 performs pool before last 1x1 conv - x = self.global_pool(x) # always need to pool here regardless of flag - x = self.conv_head(x) - # no BN - x = self.act_fn(x, inplace=True) - if pool: - # expect flattened output if pool is true, otherwise keep dim - x = x.view(x.size(0), -1) - else: - if self.conv_head is not None: - x = self.conv_head(x) - x = self.bn2(x) - x = self.act_fn(x, inplace=True) - if pool: - x = self.global_pool(x) - x = x.view(x.size(0), -1) - return x - - def forward(self, x): - x = self.forward_features(x) - if self.drop_rate > 0.: - x = F.dropout(x, p=self.drop_rate, training=self.training) - return self.classifier(x) - - -def _gen_mnasnet_a1(channel_multiplier, num_classes=1000, **kwargs): - """Creates a mnasnet-a1 model. - - Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet - Paper: https://arxiv.org/pdf/1807.11626.pdf. - - Args: - channel_multiplier: multiplier to number of channels per layer. 
- """ - arch_def = [ - # stage 0, 112x112 in - ['ds_r1_k3_s1_e1_c16_noskip'], - # stage 1, 112x112 in - ['ir_r2_k3_s2_e6_c24'], - # stage 2, 56x56 in - ['ir_r3_k5_s2_e3_c40_se0.25'], - # stage 3, 28x28 in - ['ir_r4_k3_s2_e6_c80'], - # stage 4, 14x14in - ['ir_r2_k3_s1_e6_c112_se0.25'], - # stage 5, 14x14in - ['ir_r3_k5_s2_e6_c160_se0.25'], - # stage 6, 7x7 in - ['ir_r1_k3_s1_e6_c320'], - ] - model = GenEfficientNet( - _decode_arch_def(arch_def), - num_classes=num_classes, - stem_size=32, - channel_multiplier=channel_multiplier, - bn_args=_resolve_bn_args(kwargs), - **kwargs - ) - return model - - -def _gen_mnasnet_b1(channel_multiplier, num_classes=1000, **kwargs): - """Creates a mnasnet-b1 model. - - Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet - Paper: https://arxiv.org/pdf/1807.11626.pdf. - - Args: - channel_multiplier: multiplier to number of channels per layer. - """ - arch_def = [ - # stage 0, 112x112 in - ['ds_r1_k3_s1_c16_noskip'], - # stage 1, 112x112 in - ['ir_r3_k3_s2_e3_c24'], - # stage 2, 56x56 in - ['ir_r3_k5_s2_e3_c40'], - # stage 3, 28x28 in - ['ir_r3_k5_s2_e6_c80'], - # stage 4, 14x14in - ['ir_r2_k3_s1_e6_c96'], - # stage 5, 14x14in - ['ir_r4_k5_s2_e6_c192'], - # stage 6, 7x7 in - ['ir_r1_k3_s1_e6_c320_noskip'] - ] - model = GenEfficientNet( - _decode_arch_def(arch_def), - num_classes=num_classes, - stem_size=32, - channel_multiplier=channel_multiplier, - bn_args=_resolve_bn_args(kwargs), - **kwargs - ) - return model - - -def _gen_mnasnet_small(channel_multiplier, num_classes=1000, **kwargs): - """Creates a mnasnet-b1 model. - - Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet - Paper: https://arxiv.org/pdf/1807.11626.pdf. - - Args: - channel_multiplier: multiplier to number of channels per layer. 
- """ - arch_def = [ - ['ds_r1_k3_s1_c8'], - ['ir_r1_k3_s2_e3_c16'], - ['ir_r2_k3_s2_e6_c16'], - ['ir_r4_k5_s2_e6_c32_se0.25'], - ['ir_r3_k3_s1_e6_c32_se0.25'], - ['ir_r3_k5_s2_e6_c88_se0.25'], - ['ir_r1_k3_s1_e6_c144'] - ] - model = GenEfficientNet( - _decode_arch_def(arch_def), - num_classes=num_classes, - stem_size=8, - channel_multiplier=channel_multiplier, - bn_args=_resolve_bn_args(kwargs), - **kwargs - ) - return model - - -def _gen_mobilenet_v1(channel_multiplier, num_classes=1000, **kwargs): - """ Generate MobileNet-V1 network - Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py - Paper: https://arxiv.org/abs/1801.04381 - """ - arch_def = [ - ['dsa_r1_k3_s1_c64'], - ['dsa_r2_k3_s2_c128'], - ['dsa_r2_k3_s2_c256'], - ['dsa_r6_k3_s2_c512'], - ['dsa_r2_k3_s2_c1024'], - ] - model = GenEfficientNet( - _decode_arch_def(arch_def), - num_classes=num_classes, - stem_size=32, - num_features=1024, - channel_multiplier=channel_multiplier, - bn_args=_resolve_bn_args(kwargs), - act_fn=F.relu6, - head_conv='none', - **kwargs - ) - return model - - -def _gen_mobilenet_v2(channel_multiplier, num_classes=1000, **kwargs): - """ Generate MobileNet-V2 network - Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py - Paper: https://arxiv.org/abs/1801.04381 - """ - arch_def = [ - ['ds_r1_k3_s1_c16'], - ['ir_r2_k3_s2_e6_c24'], - ['ir_r3_k3_s2_e6_c32'], - ['ir_r4_k3_s2_e6_c64'], - ['ir_r3_k3_s1_e6_c96'], - ['ir_r3_k3_s2_e6_c160'], - ['ir_r1_k3_s1_e6_c320'], - ] - model = GenEfficientNet( - _decode_arch_def(arch_def), - num_classes=num_classes, - stem_size=32, - channel_multiplier=channel_multiplier, - bn_args=_resolve_bn_args(kwargs), - act_fn=F.relu6, - **kwargs - ) - return model - - -def _gen_mobilenet_v3(channel_multiplier, num_classes=1000, **kwargs): - """Creates a MobileNet-V3 model. - - Ref impl: ? - Paper: https://arxiv.org/abs/1905.02244 - - Args: - channel_multiplier: multiplier to number of channels per layer. 
- """ - arch_def = [ - # stage 0, 112x112 in - ['ds_r1_k3_s1_e1_c16_nre_noskip'], # relu - # stage 1, 112x112 in - ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu - # stage 2, 56x56 in - ['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu - # stage 3, 28x28 in - ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish - # stage 4, 14x14in - ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish - # stage 5, 14x14in - ['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish - # stage 6, 7x7 in - ['cn_r1_k1_s1_c960'], # hard-swish - ] - model = GenEfficientNet( - _decode_arch_def(arch_def), - num_classes=num_classes, - stem_size=16, - channel_multiplier=channel_multiplier, - bn_args=_resolve_bn_args(kwargs), - act_fn=hard_swish, - se_gate_fn=hard_sigmoid, - se_reduce_mid=True, - head_conv='efficient', - **kwargs - ) - return model - - -def _gen_chamnet_v1(channel_multiplier, num_classes=1000, **kwargs): - """ Generate Chameleon Network (ChamNet) - - Paper: https://arxiv.org/abs/1812.08934 - Ref Impl: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py - - FIXME: this a bit of an educated guess based on trunkd def in maskrcnn_benchmark - """ - arch_def = [ - ['ir_r1_k3_s1_e1_c24'], - ['ir_r2_k7_s2_e4_c48'], - ['ir_r5_k3_s2_e7_c64'], - ['ir_r7_k5_s2_e12_c56'], - ['ir_r5_k3_s1_e8_c88'], - ['ir_r4_k3_s2_e7_c152'], - ['ir_r1_k3_s1_e10_c104'], - ] - model = GenEfficientNet( - _decode_arch_def(arch_def), - num_classes=num_classes, - stem_size=32, - num_features=1280, # no idea what this is? try mobile/mnasnet default? - channel_multiplier=channel_multiplier, - bn_args=_resolve_bn_args(kwargs), - **kwargs - ) - return model - - -def _gen_chamnet_v2(channel_multiplier, num_classes=1000, **kwargs): - """ Generate Chameleon Network (ChamNet) - - Paper: https://arxiv.org/abs/1812.08934 - Ref Impl: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py - - FIXME: this a bit of an educated guess based on trunk def in maskrcnn_benchmark - """ - arch_def = [ - ['ir_r1_k3_s1_e1_c24'], - ['ir_r4_k5_s2_e8_c32'], - ['ir_r6_k7_s2_e5_c48'], - ['ir_r3_k5_s2_e9_c56'], - ['ir_r6_k3_s1_e6_c56'], - ['ir_r6_k3_s2_e2_c152'], - ['ir_r1_k3_s1_e6_c112'], - ] - model = GenEfficientNet( - _decode_arch_def(arch_def), - num_classes=num_classes, - stem_size=32, - num_features=1280, # no idea what this is? try mobile/mnasnet default? 
- channel_multiplier=channel_multiplier, - bn_args=_resolve_bn_args(kwargs), - **kwargs - ) - return model - - -def _gen_fbnetc(channel_multiplier, num_classes=1000, **kwargs): - """ FBNet-C - - Paper: https://arxiv.org/abs/1812.03443 - Ref Impl: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py - - NOTE: the impl above does not relate to the 'C' variant here, that was derived from paper, - it was used to confirm some building block details - """ - arch_def = [ - ['ir_r1_k3_s1_e1_c16'], - ['ir_r1_k3_s2_e6_c24', 'ir_r2_k3_s1_e1_c24'], - ['ir_r1_k5_s2_e6_c32', 'ir_r1_k5_s1_e3_c32', 'ir_r1_k5_s1_e6_c32', 'ir_r1_k3_s1_e6_c32'], - ['ir_r1_k5_s2_e6_c64', 'ir_r1_k5_s1_e3_c64', 'ir_r2_k5_s1_e6_c64'], - ['ir_r3_k5_s1_e6_c112', 'ir_r1_k5_s1_e3_c112'], - ['ir_r4_k5_s2_e6_c184'], - ['ir_r1_k3_s1_e6_c352'], - ] - model = GenEfficientNet( - _decode_arch_def(arch_def), - num_classes=num_classes, - stem_size=16, - num_features=1984, # paper suggests this, but is not 100% clear - channel_multiplier=channel_multiplier, - bn_args=_resolve_bn_args(kwargs), - **kwargs - ) - return model - - -def _gen_spnasnet(channel_multiplier, num_classes=1000, **kwargs): - """Creates the Single-Path NAS model from search targeted for Pixel1 phone. - - Paper: https://arxiv.org/abs/1904.02877 - - Args: - channel_multiplier: multiplier to number of channels per layer. - """ - arch_def = [ - # stage 0, 112x112 in - ['ds_r1_k3_s1_c16_noskip'], - # stage 1, 112x112 in - ['ir_r3_k3_s2_e3_c24'], - # stage 2, 56x56 in - ['ir_r1_k5_s2_e6_c40', 'ir_r3_k3_s1_e3_c40'], - # stage 3, 28x28 in - ['ir_r1_k5_s2_e6_c80', 'ir_r3_k3_s1_e3_c80'], - # stage 4, 14x14in - ['ir_r1_k5_s1_e6_c96', 'ir_r3_k5_s1_e3_c96'], - # stage 5, 14x14in - ['ir_r4_k5_s2_e6_c192'], - # stage 6, 7x7 in - ['ir_r1_k3_s1_e6_c320_noskip'] - ] - model = GenEfficientNet( - _decode_arch_def(arch_def), - num_classes=num_classes, - stem_size=32, - channel_multiplier=channel_multiplier, - bn_args=_resolve_bn_args(kwargs), - **kwargs - ) - return model - - -def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000, **kwargs): - """Creates an EfficientNet model. 
- - Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py - Paper: https://arxiv.org/abs/1905.11946 - - EfficientNet params - name: (channel_multiplier, depth_multiplier, resolution, dropout_rate) - 'efficientnet-b0': (1.0, 1.0, 224, 0.2), - 'efficientnet-b1': (1.0, 1.1, 240, 0.2), - 'efficientnet-b2': (1.1, 1.2, 260, 0.3), - 'efficientnet-b3': (1.2, 1.4, 300, 0.3), - 'efficientnet-b4': (1.4, 1.8, 380, 0.4), - 'efficientnet-b5': (1.6, 2.2, 456, 0.4), - 'efficientnet-b6': (1.8, 2.6, 528, 0.5), - 'efficientnet-b7': (2.0, 3.1, 600, 0.5), - - Args: - channel_multiplier: multiplier to number of channels per layer - depth_multiplier: multiplier to number of repeats per stage - - """ - arch_def = [ - ['ds_r1_k3_s1_e1_c16_se0.25'], - ['ir_r2_k3_s2_e6_c24_se0.25'], - ['ir_r2_k5_s2_e6_c40_se0.25'], - ['ir_r3_k3_s2_e6_c80_se0.25'], - ['ir_r3_k5_s1_e6_c112_se0.25'], - ['ir_r4_k5_s2_e6_c192_se0.25'], - ['ir_r1_k3_s1_e6_c320_se0.25'], - ] - num_features = _round_channels(1280, channel_multiplier, 8, None) - model = GenEfficientNet( - _decode_arch_def(arch_def, depth_multiplier), - num_classes=num_classes, - stem_size=32, - channel_multiplier=channel_multiplier, - num_features=num_features, - bn_args=_resolve_bn_args(kwargs), - act_fn=swish, - **kwargs - ) - return model - - -def _gen_efficientnet_edge(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000, **kwargs): - arch_def = [ - # NOTE `fc` is present to override a mismatch between stem channels and in chs not - # present in other models - ['er_r1_k3_s1_e4_c24_fc24_noskip'], - ['er_r2_k3_s2_e8_c32'], - ['er_r4_k3_s2_e8_c48'], - ['ir_r5_k5_s2_e8_c96'], - ['ir_r4_k5_s1_e8_c144'], - ['ir_r2_k5_s2_e8_c192'], - ] - num_features = _round_channels(1280, channel_multiplier, 8, None) - model = GenEfficientNet( - _decode_arch_def(arch_def, depth_multiplier), - num_classes=num_classes, - stem_size=32, - channel_multiplier=channel_multiplier, - num_features=num_features, - bn_args=_resolve_bn_args(kwargs), - act_fn=F.relu, - **kwargs - ) - return model - - -def _gen_mixnet_s(channel_multiplier=1.0, num_classes=1000, **kwargs): - """Creates a MixNet Small model. - - Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet - Paper: https://arxiv.org/abs/1907.09595 - """ - arch_def = [ - # stage 0, 112x112 in - ['ds_r1_k3_s1_e1_c16'], # relu - # stage 1, 112x112 in - ['ir_r1_k3_a1.1_p1.1_s2_e6_c24', 'ir_r1_k3_a1.1_p1.1_s1_e3_c24'], # relu - # stage 2, 56x56 in - ['ir_r1_k3.5.7_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish - # stage 3, 28x28 in - ['ir_r1_k3.5.7_p1.1_s2_e6_c80_se0.25_nsw', 'ir_r2_k3.5_p1.1_s1_e6_c80_se0.25_nsw'], # swish - # stage 4, 14x14in - ['ir_r1_k3.5.7_a1.1_p1.1_s1_e6_c120_se0.5_nsw', 'ir_r2_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish - # stage 5, 14x14in - ['ir_r1_k3.5.7.9.11_s2_e6_c200_se0.5_nsw', 'ir_r2_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish - # 7x7 - ] - model = GenEfficientNet( - _decode_arch_def(arch_def), - num_classes=num_classes, - stem_size=16, - num_features=1536, - channel_multiplier=channel_multiplier, - bn_args=_resolve_bn_args(kwargs), - act_fn=F.relu, - **kwargs - ) - return model - - -def _gen_mixnet_m(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000, **kwargs): - """Creates a MixNet Medium-Large model. 
- - Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet - Paper: https://arxiv.org/abs/1907.09595 - """ - arch_def = [ - # stage 0, 112x112 in - ['ds_r1_k3_s1_e1_c24'], # relu - # stage 1, 112x112 in - ['ir_r1_k3.5.7_a1.1_p1.1_s2_e6_c32', 'ir_r1_k3_a1.1_p1.1_s1_e3_c32'], # relu - # stage 2, 56x56 in - ['ir_r1_k3.5.7.9_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish - # stage 3, 28x28 in - ['ir_r1_k3.5.7_s2_e6_c80_se0.25_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e6_c80_se0.25_nsw'], # swish - # stage 4, 14x14in - ['ir_r1_k3_s1_e6_c120_se0.5_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish - # stage 5, 14x14in - ['ir_r1_k3.5.7.9_s2_e6_c200_se0.5_nsw', 'ir_r3_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish - # 7x7 - ] - model = GenEfficientNet( - _decode_arch_def(arch_def, depth_multiplier=depth_multiplier, depth_trunc='round'), - num_classes=num_classes, - stem_size=24, - num_features=1536, - channel_multiplier=channel_multiplier, - bn_args=_resolve_bn_args(kwargs), - act_fn=F.relu, - **kwargs - ) - return model - - -@register_model -def mnasnet_050(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ MNASNet B1, depth multiplier of 0.5. """ - default_cfg = default_cfgs['mnasnet_050'] - model = _gen_mnasnet_b1(0.5, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def mnasnet_075(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ MNASNet B1, depth multiplier of 0.75. """ - default_cfg = default_cfgs['mnasnet_075'] - model = _gen_mnasnet_b1(0.75, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def mnasnet_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ MNASNet B1, depth multiplier of 1.0. """ - default_cfg = default_cfgs['mnasnet_100'] - model = _gen_mnasnet_b1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def mnasnet_b1(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ MNASNet B1, depth multiplier of 1.0. """ - return mnasnet_100(pretrained, num_classes, in_chans, **kwargs) - - -@register_model -def mnasnet_140(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ MNASNet B1, depth multiplier of 1.4 """ - default_cfg = default_cfgs['mnasnet_140'] - model = _gen_mnasnet_b1(1.4, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def semnasnet_050(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ MNASNet A1 (w/ SE), depth multiplier of 0.5 """ - default_cfg = default_cfgs['semnasnet_050'] - model = _gen_mnasnet_a1(0.5, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def semnasnet_075(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ MNASNet A1 (w/ SE), depth multiplier of 0.75. 
""" - default_cfg = default_cfgs['semnasnet_075'] - model = _gen_mnasnet_a1(0.75, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def semnasnet_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """ - default_cfg = default_cfgs['semnasnet_100'] - model = _gen_mnasnet_a1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def mnasnet_a1(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """ - return semnasnet_100(pretrained, num_classes, in_chans, **kwargs) - - -@register_model -def semnasnet_140(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ MNASNet A1 (w/ SE), depth multiplier of 1.4. """ - default_cfg = default_cfgs['semnasnet_140'] - model = _gen_mnasnet_a1(1.4, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def mnasnet_small(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ MNASNet Small, depth multiplier of 1.0. """ - default_cfg = default_cfgs['mnasnet_small'] - model = _gen_mnasnet_small(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def mobilenetv1_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ MobileNet V1 """ - default_cfg = default_cfgs['mobilenetv1_100'] - model = _gen_mobilenet_v1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def mobilenetv2_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ MobileNet V2 """ - default_cfg = default_cfgs['mobilenetv2_100'] - model = _gen_mobilenet_v2(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def mobilenetv3_050(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ MobileNet V3 """ - default_cfg = default_cfgs['mobilenetv3_050'] - model = _gen_mobilenet_v3(0.5, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def mobilenetv3_075(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ MobileNet V3 """ - default_cfg = default_cfgs['mobilenetv3_075'] - model = _gen_mobilenet_v3(0.75, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def mobilenetv3_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ MobileNet V3 """ - default_cfg = default_cfgs['mobilenetv3_100'] - if pretrained: - # pretrained model trained with non-default BN epsilon - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT - model 
= _gen_mobilenet_v3(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def fbnetc_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ FBNet-C """ - default_cfg = default_cfgs['fbnetc_100'] - if pretrained: - # pretrained model trained with non-default BN epsilon - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT - model = _gen_fbnetc(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def chamnetv1_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ ChamNet """ - default_cfg = default_cfgs['chamnetv1_100'] - model = _gen_chamnet_v1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def chamnetv2_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ ChamNet """ - default_cfg = default_cfgs['chamnetv2_100'] - model = _gen_chamnet_v2(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def spnasnet_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ Single-Path NAS Pixel1""" - default_cfg = default_cfgs['spnasnet_100'] - model = _gen_spnasnet(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def efficientnet_b0(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ EfficientNet-B0 """ - default_cfg = default_cfgs['efficientnet_b0'] - # NOTE for train, drop_rate should be 0.2 - #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg - model = _gen_efficientnet( - channel_multiplier=1.0, depth_multiplier=1.0, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def efficientnet_b1(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ EfficientNet-B1 """ - default_cfg = default_cfgs['efficientnet_b1'] - # NOTE for train, drop_rate should be 0.2 - #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg - model = _gen_efficientnet( - channel_multiplier=1.0, depth_multiplier=1.1, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def efficientnet_b2(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ EfficientNet-B2 """ - default_cfg = default_cfgs['efficientnet_b2'] - # NOTE for train, drop_rate should be 0.3 - #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg - model = _gen_efficientnet( - channel_multiplier=1.1, depth_multiplier=1.2, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def efficientnet_b3(pretrained=False, 
num_classes=1000, in_chans=3, **kwargs): - """ EfficientNet-B3 """ - default_cfg = default_cfgs['efficientnet_b3'] - # NOTE for train, drop_rate should be 0.3 - #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg - model = _gen_efficientnet( - channel_multiplier=1.2, depth_multiplier=1.4, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def efficientnet_b4(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ EfficientNet-B4 """ - default_cfg = default_cfgs['efficientnet_b4'] - # NOTE for train, drop_rate should be 0.4 - #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg - model = _gen_efficientnet( - channel_multiplier=1.4, depth_multiplier=1.8, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ EfficientNet-B5 """ - # NOTE for train, drop_rate should be 0.4 - #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg - default_cfg = default_cfgs['efficientnet_b5'] - model = _gen_efficientnet( - channel_multiplier=1.6, depth_multiplier=2.2, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def efficientnet_b6(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ EfficientNet-B6 """ - # NOTE for train, drop_rate should be 0.5 - #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg - default_cfg = default_cfgs['efficientnet_b6'] - model = _gen_efficientnet( - channel_multiplier=1.8, depth_multiplier=2.6, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def efficientnet_b7(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ EfficientNet-B7 """ - # NOTE for train, drop_rate should be 0.5 - #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg - default_cfg = default_cfgs['efficientnet_b7'] - model = _gen_efficientnet( - channel_multiplier=2.0, depth_multiplier=3.1, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def efficientnet_es(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ EfficientNet-Edge Small. """ - default_cfg = default_cfgs['efficientnet_es'] - model = _gen_efficientnet_edge( - channel_multiplier=1.0, depth_multiplier=1.0, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def efficientnet_em(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ EfficientNet-Edge-Medium. 
""" - default_cfg = default_cfgs['efficientnet_em'] - model = _gen_efficientnet_edge( - channel_multiplier=1.0, depth_multiplier=1.1, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def efficientnet_el(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ EfficientNet-Edge-Large. """ - default_cfg = default_cfgs['efficientnet_el'] - model = _gen_efficientnet_edge( - channel_multiplier=1.2, depth_multiplier=1.4, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def tf_efficientnet_b0(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ EfficientNet-B0. Tensorflow compatible variant """ - default_cfg = default_cfgs['tf_efficientnet_b0'] - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnet( - channel_multiplier=1.0, depth_multiplier=1.0, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def tf_efficientnet_b1(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ EfficientNet-B1. Tensorflow compatible variant """ - default_cfg = default_cfgs['tf_efficientnet_b1'] - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnet( - channel_multiplier=1.0, depth_multiplier=1.1, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def tf_efficientnet_b2(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ EfficientNet-B2. Tensorflow compatible variant """ - default_cfg = default_cfgs['tf_efficientnet_b2'] - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnet( - channel_multiplier=1.1, depth_multiplier=1.2, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def tf_efficientnet_b3(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ EfficientNet-B3. Tensorflow compatible variant """ - default_cfg = default_cfgs['tf_efficientnet_b3'] - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnet( - channel_multiplier=1.2, depth_multiplier=1.4, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def tf_efficientnet_b4(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ EfficientNet-B4. 
Tensorflow compatible variant """ - default_cfg = default_cfgs['tf_efficientnet_b4'] - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnet( - channel_multiplier=1.4, depth_multiplier=1.8, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def tf_efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ EfficientNet-B5. Tensorflow compatible variant """ - default_cfg = default_cfgs['tf_efficientnet_b5'] - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnet( - channel_multiplier=1.6, depth_multiplier=2.2, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def tf_efficientnet_b6(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ EfficientNet-B6. Tensorflow compatible variant """ - # NOTE for train, drop_rate should be 0.5 - default_cfg = default_cfgs['tf_efficientnet_b6'] - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnet( - channel_multiplier=1.8, depth_multiplier=2.6, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def tf_efficientnet_b7(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ EfficientNet-B7. Tensorflow compatible variant """ - # NOTE for train, drop_rate should be 0.5 - default_cfg = default_cfgs['tf_efficientnet_b7'] - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnet( - channel_multiplier=2.0, depth_multiplier=3.1, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def tf_efficientnet_es(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ EfficientNet-Edge Small. Tensorflow compatible variant """ - default_cfg = default_cfgs['tf_efficientnet_es'] - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnet_edge( - channel_multiplier=1.0, depth_multiplier=1.0, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def tf_efficientnet_em(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ EfficientNet-Edge-Medium. Tensorflow compatible variant """ - default_cfg = default_cfgs['tf_efficientnet_em'] - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnet_edge( - channel_multiplier=1.0, depth_multiplier=1.1, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def tf_efficientnet_el(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ EfficientNet-Edge-Large. 
Tensorflow compatible variant """ - default_cfg = default_cfgs['tf_efficientnet_el'] - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnet_edge( - channel_multiplier=1.2, depth_multiplier=1.4, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def mixnet_s(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Creates a MixNet Small model. - """ - default_cfg = default_cfgs['mixnet_s'] - model = _gen_mixnet_s( - channel_multiplier=1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def mixnet_m(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Creates a MixNet Medium model. - """ - default_cfg = default_cfgs['mixnet_m'] - model = _gen_mixnet_m( - channel_multiplier=1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def mixnet_l(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Creates a MixNet Large model. - """ - default_cfg = default_cfgs['mixnet_l'] - model = _gen_mixnet_m( - channel_multiplier=1.3, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def mixnet_xl(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Creates a MixNet Extra-Large model. - Not a paper spec, experimental def by RW w/ depth scaling. - """ - default_cfg = default_cfgs['mixnet_xl'] - #kwargs['drop_connect_rate'] = 0.2 - model = _gen_mixnet_m( - channel_multiplier=1.6, depth_multiplier=1.2, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def mixnet_xxl(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Creates a MixNet Double Extra Large model. - Not a paper spec, experimental def by RW w/ depth scaling. - """ - default_cfg = default_cfgs['mixnet_xxl'] - # kwargs['drop_connect_rate'] = 0.2 - model = _gen_mixnet_m( - channel_multiplier=2.4, depth_multiplier=1.3, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def tf_mixnet_s(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Creates a MixNet Small model. Tensorflow compatible variant - """ - default_cfg = default_cfgs['tf_mixnet_s'] - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_mixnet_s( - channel_multiplier=1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def tf_mixnet_m(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Creates a MixNet Medium model. 
Tensorflow compatible variant - """ - default_cfg = default_cfgs['tf_mixnet_m'] - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_mixnet_m( - channel_multiplier=1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -@register_model -def tf_mixnet_l(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Creates a MixNet Large model. Tensorflow compatible variant - """ - default_cfg = default_cfgs['tf_mixnet_l'] - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_mixnet_m( - channel_multiplier=1.3, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -def gen_efficientnet_model_names(): - return set(_models) diff --git a/timm/models/gluon_resnet.py b/timm/models/gluon_resnet.py index 715e0950..3d0f926f 100644 --- a/timm/models/gluon_resnet.py +++ b/timm/models/gluon_resnet.py @@ -11,11 +11,9 @@ import torch.nn.functional as F from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import SelectAdaptivePool2d from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD - -__all__ = ['GluonResNet'] +from .resnet import ResNet, Bottleneck, BasicBlock def _cfg(url='', **kwargs): @@ -57,312 +55,12 @@ default_cfgs = { } -def _get_padding(kernel_size, stride, dilation=1): - padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 - return padding - - -class SEModule(nn.Module): - - def __init__(self, channels, reduction_channels): - super(SEModule, self).__init__() - #self.avg_pool = nn.AdaptiveAvgPool2d(1) - self.fc1 = nn.Conv2d( - channels, reduction_channels, kernel_size=1, padding=0, bias=True) - self.relu = nn.ReLU() - self.fc2 = nn.Conv2d( - reduction_channels, channels, kernel_size=1, padding=0, bias=True) - self.sigmoid = nn.Sigmoid() - - def forward(self, x): - module_input = x - #x = self.avg_pool(x) - x = x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1) - x = self.fc1(x) - x = self.relu(x) - x = self.fc2(x) - x = self.sigmoid(x) - return module_input * x - - -class BasicBlockGl(nn.Module): - expansion = 1 - - def __init__(self, inplanes, planes, stride=1, downsample=None, - cardinality=1, base_width=64, use_se=False, - reduce_first=1, dilation=1, previous_dilation=1, norm_layer=nn.BatchNorm2d): - super(BasicBlockGl, self).__init__() - - assert cardinality == 1, 'BasicBlock only supports cardinality of 1' - assert base_width == 64, 'BasicBlock doest not support changing base width' - first_planes = planes // reduce_first - outplanes = planes * self.expansion - - self.conv1 = nn.Conv2d( - inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation, - dilation=dilation, bias=False) - self.bn1 = norm_layer(first_planes) - self.relu = nn.ReLU() - self.conv2 = nn.Conv2d( - first_planes, outplanes, kernel_size=3, padding=previous_dilation, - dilation=previous_dilation, bias=False) - self.bn2 = norm_layer(outplanes) - self.se = SEModule(outplanes, planes // 4) if use_se else None - self.downsample = downsample - self.stride = stride - self.dilation = dilation - - def forward(self, x): - residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - out = self.conv2(out) - out = self.bn2(out) - - if self.se is not None: - out = self.se(out) - - if 
self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.relu(out) - - return out - - -class BottleneckGl(nn.Module): - expansion = 4 - - def __init__(self, inplanes, planes, stride=1, downsample=None, - cardinality=1, base_width=64, use_se=False, - reduce_first=1, dilation=1, previous_dilation=1, norm_layer=nn.BatchNorm2d): - super(BottleneckGl, self).__init__() - - width = int(math.floor(planes * (base_width / 64)) * cardinality) - first_planes = width // reduce_first - outplanes = planes * self.expansion - - self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False) - self.bn1 = norm_layer(first_planes) - self.conv2 = nn.Conv2d( - first_planes, width, kernel_size=3, stride=stride, - padding=dilation, dilation=dilation, groups=cardinality, bias=False) - self.bn2 = norm_layer(width) - self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) - self.bn3 = norm_layer(outplanes) - self.se = SEModule(outplanes, planes // 4) if use_se else None - self.relu = nn.ReLU() - self.downsample = downsample - self.stride = stride - self.dilation = dilation - - def forward(self, x): - residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - - out = self.conv3(out) - out = self.bn3(out) - - if self.se is not None: - out = self.se(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.relu(out) - - return out - - -class GluonResNet(nn.Module): - """ Gluon ResNet (https://gluon-cv.mxnet.io/model_zoo/classification.html) - This class implements all variants of ResNet, ResNeXt, SE-ResNeXt, and SENet found in the gluon model zoo that - * have stride in 3x3 conv layer of bottleneck - * have conv-bn-act ordering - - Included ResNet variants are: - * v1b - 7x7 stem, stem_width=64, same as torchvision ResNet (checkpoint compatible), or NVIDIA ResNet 'v1.5' - * v1c - 3 layer deep 3x3 stem, stem_width = 32 - * v1d - 3 layer deep 3x3 stem, stem_width = 32, average pool in downsample - * v1e - 3 layer deep 3x3 stem, stem_width = 64, average pool in downsample *no pretrained weights available - * v1s - 3 layer deep 3x3 stem, stem_width = 64 - - ResNeXt is standard and checkpoint compatible with torchvision pretrained models. 7x7 stem, - stem_width = 64, standard cardinality and base width calcs - - SE-ResNeXt is standard. 7x7 stem, stem_width = 64, - checkpoints are not compatible with Cadene pretrained, but could be with key mapping - - SENet-154 is standard. 3 layer deep 3x3 stem (same as v1c-v1s), stem_width = 64, cardinality=64, - reduction by 2 on width of first bottleneck convolution, 3x3 downsample convs after first block - - Original ResNet-V1, ResNet-V2 (bn-act-conv), and SE-ResNet (stride in first bottleneck conv) are NOT supported. - They do have Gluon pretrained weights but are, at best, comparable (or inferior) to the supported models. - - Parameters - ---------- - block : Block - Class for the residual block. Options are BasicBlockGl, BottleneckGl. - layers : list of int - Numbers of layers in each block - num_classes : int, default 1000 - Number of classification classes. - deep_stem : bool, default False - Whether to replace the 7x7 conv1 with 3 3x3 convolution layers. 
- block_reduce_first: int, default 1 - Reduction factor for first convolution output width of residual blocks, - 1 for all archs except senets, where 2 - down_kernel_size: int, default 1 - Kernel size of residual block downsampling path, 1x1 for most archs, 3x3 for senets - avg_down : bool, default False - Whether to use average pooling for projection skip connection between stages/downsample. - dilated : bool, default False - Applying dilation strategy to pretrained ResNet yielding a stride-8 model, - typically used in Semantic Segmentation. - """ - def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False, - cardinality=1, base_width=64, stem_width=64, deep_stem=False, - block_reduce_first=1, down_kernel_size=1, avg_down=False, dilated=False, - norm_layer=nn.BatchNorm2d, drop_rate=0.0, global_pool='avg'): - self.num_classes = num_classes - self.inplanes = stem_width * 2 if deep_stem else 64 - self.cardinality = cardinality - self.base_width = base_width - self.drop_rate = drop_rate - self.expansion = block.expansion - self.dilated = dilated - super(GluonResNet, self).__init__() - - if not deep_stem: - self.conv1 = nn.Conv2d(in_chans, stem_width, kernel_size=7, stride=2, padding=3, bias=False) - else: - conv1_modules = [ - nn.Conv2d(in_chans, stem_width, 3, stride=2, padding=1, bias=False), - norm_layer(stem_width), - nn.ReLU(), - nn.Conv2d(stem_width, stem_width, 3, stride=1, padding=1, bias=False), - norm_layer(stem_width), - nn.ReLU(), - nn.Conv2d(stem_width, self.inplanes, 3, stride=1, padding=1, bias=False), - ] - self.conv1 = nn.Sequential(*conv1_modules) - self.bn1 = norm_layer(self.inplanes) - self.relu = nn.ReLU() - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) - stride_3_4 = 1 if self.dilated else 2 - dilation_3 = 2 if self.dilated else 1 - dilation_4 = 4 if self.dilated else 1 - self.layer1 = self._make_layer( - block, 64, layers[0], stride=1, reduce_first=block_reduce_first, - use_se=use_se, avg_down=avg_down, down_kernel_size=1, norm_layer=norm_layer) - self.layer2 = self._make_layer( - block, 128, layers[1], stride=2, reduce_first=block_reduce_first, - use_se=use_se, avg_down=avg_down, down_kernel_size=down_kernel_size, norm_layer=norm_layer) - self.layer3 = self._make_layer( - block, 256, layers[2], stride=stride_3_4, dilation=dilation_3, reduce_first=block_reduce_first, - use_se=use_se, avg_down=avg_down, down_kernel_size=down_kernel_size, norm_layer=norm_layer) - self.layer4 = self._make_layer( - block, 512, layers[3], stride=stride_3_4, dilation=dilation_4, reduce_first=block_reduce_first, - use_se=use_se, avg_down=avg_down, down_kernel_size=down_kernel_size, norm_layer=norm_layer) - self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - self.num_features = 512 * block.expansion - self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(m, nn.BatchNorm2d): - nn.init.constant_(m.weight, 1.) - nn.init.constant_(m.bias, 0.) 
- - def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1, - use_se=False, avg_down=False, down_kernel_size=1, norm_layer=nn.BatchNorm2d): - downsample = None - if stride != 1 or self.inplanes != planes * block.expansion: - downsample_padding = _get_padding(down_kernel_size, stride) - if avg_down: - avg_stride = stride if dilation == 1 else 1 - downsample_layers = [ - nn.AvgPool2d(avg_stride, avg_stride, ceil_mode=True, count_include_pad=False), - nn.Conv2d(self.inplanes, planes * block.expansion, down_kernel_size, - stride=1, padding=downsample_padding, bias=False), - norm_layer(planes * block.expansion), - ] - else: - downsample_layers = [ - nn.Conv2d(self.inplanes, planes * block.expansion, down_kernel_size, - stride=stride, padding=downsample_padding, bias=False), - norm_layer(planes * block.expansion), - ] - downsample = nn.Sequential(*downsample_layers) - - first_dilation = 1 if dilation in (1, 2) else 2 - layers = [block( - self.inplanes, planes, stride, downsample, - cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first, - use_se=use_se, dilation=first_dilation, previous_dilation=dilation, norm_layer=norm_layer)] - self.inplanes = planes * block.expansion - for i in range(1, blocks): - layers.append(block( - self.inplanes, planes, - cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first, - use_se=use_se, dilation=dilation, previous_dilation=dilation, norm_layer=norm_layer)) - - return nn.Sequential(*layers) - - def get_classifier(self): - return self.fc - - def reset_classifier(self, num_classes, global_pool='avg'): - self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - self.num_classes = num_classes - del self.fc - if num_classes: - self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) - else: - self.fc = None - - def forward_features(self, x, pool=True): - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - x = self.maxpool(x) - - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - - if pool: - x = self.global_pool(x) - x = x.view(x.size(0), -1) - return x - - def forward(self, x): - x = self.forward_features(x) - if self.drop_rate > 0.: - x = F.dropout(x, p=self.drop_rate, training=self.training) - x = self.fc(x) - return x - - @register_model def gluon_resnet18_v1b(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-18 model. """ default_cfg = default_cfgs['gluon_resnet18_v1b'] - model = GluonResNet(BasicBlockGl, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, **kwargs) + model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) @@ -374,7 +72,7 @@ def gluon_resnet34_v1b(pretrained=False, num_classes=1000, in_chans=3, **kwargs) """Constructs a ResNet-34 model. """ default_cfg = default_cfgs['gluon_resnet34_v1b'] - model = GluonResNet(BasicBlockGl, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs) + model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) @@ -386,7 +84,7 @@ def gluon_resnet50_v1b(pretrained=False, num_classes=1000, in_chans=3, **kwargs) """Constructs a ResNet-50 model. 
""" default_cfg = default_cfgs['gluon_resnet50_v1b'] - model = GluonResNet(BottleneckGl, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs) + model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) @@ -398,7 +96,7 @@ def gluon_resnet101_v1b(pretrained=False, num_classes=1000, in_chans=3, **kwargs """Constructs a ResNet-101 model. """ default_cfg = default_cfgs['gluon_resnet101_v1b'] - model = GluonResNet(BottleneckGl, [3, 4, 23, 3], num_classes=num_classes, in_chans=in_chans, **kwargs) + model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, in_chans=in_chans, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) @@ -410,7 +108,7 @@ def gluon_resnet152_v1b(pretrained=False, num_classes=1000, in_chans=3, **kwargs """Constructs a ResNet-152 model. """ default_cfg = default_cfgs['gluon_resnet152_v1b'] - model = GluonResNet(BottleneckGl, [3, 8, 36, 3], num_classes=num_classes, in_chans=in_chans, **kwargs) + model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, in_chans=in_chans, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) @@ -422,8 +120,8 @@ def gluon_resnet50_v1c(pretrained=False, num_classes=1000, in_chans=3, **kwargs) """Constructs a ResNet-50 model. """ default_cfg = default_cfgs['gluon_resnet50_v1c'] - model = GluonResNet(BottleneckGl, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, - stem_width=32, deep_stem=True, **kwargs) + model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, + stem_width=32, deep_stem=True, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) @@ -435,8 +133,8 @@ def gluon_resnet101_v1c(pretrained=False, num_classes=1000, in_chans=3, **kwargs """Constructs a ResNet-101 model. """ default_cfg = default_cfgs['gluon_resnet101_v1c'] - model = GluonResNet(BottleneckGl, [3, 4, 23, 3], num_classes=num_classes, in_chans=in_chans, - stem_width=32, deep_stem=True, **kwargs) + model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, in_chans=in_chans, + stem_width=32, deep_stem=True, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) @@ -448,8 +146,8 @@ def gluon_resnet152_v1c(pretrained=False, num_classes=1000, in_chans=3, **kwargs """Constructs a ResNet-152 model. """ default_cfg = default_cfgs['gluon_resnet152_v1c'] - model = GluonResNet(BottleneckGl, [3, 8, 36, 3], num_classes=num_classes, in_chans=in_chans, - stem_width=32, deep_stem=True, **kwargs) + model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, in_chans=in_chans, + stem_width=32, deep_stem=True, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) @@ -461,8 +159,8 @@ def gluon_resnet50_v1d(pretrained=False, num_classes=1000, in_chans=3, **kwargs) """Constructs a ResNet-50 model. 
""" default_cfg = default_cfgs['gluon_resnet50_v1d'] - model = GluonResNet(BottleneckGl, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, - stem_width=32, deep_stem=True, avg_down=True, **kwargs) + model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, + stem_width=32, deep_stem=True, avg_down=True, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) @@ -474,8 +172,8 @@ def gluon_resnet101_v1d(pretrained=False, num_classes=1000, in_chans=3, **kwargs """Constructs a ResNet-101 model. """ default_cfg = default_cfgs['gluon_resnet101_v1d'] - model = GluonResNet(BottleneckGl, [3, 4, 23, 3], num_classes=num_classes, in_chans=in_chans, - stem_width=32, deep_stem=True, avg_down=True, **kwargs) + model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, in_chans=in_chans, + stem_width=32, deep_stem=True, avg_down=True, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) @@ -487,8 +185,8 @@ def gluon_resnet152_v1d(pretrained=False, num_classes=1000, in_chans=3, **kwargs """Constructs a ResNet-152 model. """ default_cfg = default_cfgs['gluon_resnet152_v1d'] - model = GluonResNet(BottleneckGl, [3, 8, 36, 3], num_classes=num_classes, in_chans=in_chans, - stem_width=32, deep_stem=True, avg_down=True, **kwargs) + model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, in_chans=in_chans, + stem_width=32, deep_stem=True, avg_down=True, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) @@ -500,8 +198,8 @@ def gluon_resnet50_v1e(pretrained=False, num_classes=1000, in_chans=3, **kwargs) """Constructs a ResNet-50-V1e model. No pretrained weights for any 'e' variants """ default_cfg = default_cfgs['gluon_resnet50_v1e'] - model = GluonResNet(BottleneckGl, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, - stem_width=64, deep_stem=True, avg_down=True, **kwargs) + model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, + stem_width=64, deep_stem=True, avg_down=True, **kwargs) model.default_cfg = default_cfg #if pretrained: # load_pretrained(model, default_cfg, num_classes, in_chans) @@ -513,8 +211,8 @@ def gluon_resnet101_v1e(pretrained=False, num_classes=1000, in_chans=3, **kwargs """Constructs a ResNet-101 model. """ default_cfg = default_cfgs['gluon_resnet101_v1e'] - model = GluonResNet(BottleneckGl, [3, 4, 23, 3], num_classes=num_classes, in_chans=in_chans, - stem_width=64, deep_stem=True, avg_down=True, **kwargs) + model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, in_chans=in_chans, + stem_width=64, deep_stem=True, avg_down=True, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) @@ -526,8 +224,8 @@ def gluon_resnet152_v1e(pretrained=False, num_classes=1000, in_chans=3, **kwargs """Constructs a ResNet-152 model. 
""" default_cfg = default_cfgs['gluon_resnet152_v1e'] - model = GluonResNet(BottleneckGl, [3, 8, 36, 3], num_classes=num_classes, in_chans=in_chans, - stem_width=64, deep_stem=True, avg_down=True, **kwargs) + model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, in_chans=in_chans, + stem_width=64, deep_stem=True, avg_down=True, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) @@ -539,8 +237,8 @@ def gluon_resnet50_v1s(pretrained=False, num_classes=1000, in_chans=3, **kwargs) """Constructs a ResNet-50 model. """ default_cfg = default_cfgs['gluon_resnet50_v1s'] - model = GluonResNet(BottleneckGl, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, - stem_width=64, deep_stem=True, **kwargs) + model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, + stem_width=64, deep_stem=True, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) @@ -552,8 +250,8 @@ def gluon_resnet101_v1s(pretrained=False, num_classes=1000, in_chans=3, **kwargs """Constructs a ResNet-101 model. """ default_cfg = default_cfgs['gluon_resnet101_v1s'] - model = GluonResNet(BottleneckGl, [3, 4, 23, 3], num_classes=num_classes, in_chans=in_chans, - stem_width=64, deep_stem=True, **kwargs) + model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, in_chans=in_chans, + stem_width=64, deep_stem=True, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) @@ -565,8 +263,8 @@ def gluon_resnet152_v1s(pretrained=False, num_classes=1000, in_chans=3, **kwargs """Constructs a ResNet-152 model. """ default_cfg = default_cfgs['gluon_resnet152_v1s'] - model = GluonResNet(BottleneckGl, [3, 8, 36, 3], num_classes=num_classes, in_chans=in_chans, - stem_width=64, deep_stem=True, **kwargs) + model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, in_chans=in_chans, + stem_width=64, deep_stem=True, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) @@ -578,8 +276,8 @@ def gluon_resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwar """Constructs a ResNeXt50-32x4d model. """ default_cfg = default_cfgs['gluon_resnext50_32x4d'] - model = GluonResNet( - BottleneckGl, [3, 4, 6, 3], cardinality=32, base_width=4, + model = ResNet( + Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, num_classes=num_classes, in_chans=in_chans, **kwargs) model.default_cfg = default_cfg if pretrained: @@ -592,8 +290,8 @@ def gluon_resnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwa """Constructs a ResNeXt-101 model. """ default_cfg = default_cfgs['gluon_resnext101_32x4d'] - model = GluonResNet( - BottleneckGl, [3, 4, 23, 3], cardinality=32, base_width=4, + model = ResNet( + Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=4, num_classes=num_classes, in_chans=in_chans, **kwargs) model.default_cfg = default_cfg if pretrained: @@ -606,8 +304,8 @@ def gluon_resnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **kwa """Constructs a ResNeXt-101 model. 
""" default_cfg = default_cfgs['gluon_resnext101_64x4d'] - model = GluonResNet( - BottleneckGl, [3, 4, 23, 3], cardinality=64, base_width=4, + model = ResNet( + Bottleneck, [3, 4, 23, 3], cardinality=64, base_width=4, num_classes=num_classes, in_chans=in_chans, **kwargs) model.default_cfg = default_cfg if pretrained: @@ -620,8 +318,8 @@ def gluon_seresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kw """Constructs a SEResNeXt50-32x4d model. """ default_cfg = default_cfgs['gluon_seresnext50_32x4d'] - model = GluonResNet( - BottleneckGl, [3, 4, 6, 3], cardinality=32, base_width=4, use_se=True, + model = ResNet( + Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, use_se=True, num_classes=num_classes, in_chans=in_chans, **kwargs) model.default_cfg = default_cfg if pretrained: @@ -634,8 +332,8 @@ def gluon_seresnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **k """Constructs a SEResNeXt-101-32x4d model. """ default_cfg = default_cfgs['gluon_seresnext101_32x4d'] - model = GluonResNet( - BottleneckGl, [3, 4, 23, 3], cardinality=32, base_width=4, use_se=True, + model = ResNet( + Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=4, use_se=True, num_classes=num_classes, in_chans=in_chans, **kwargs) model.default_cfg = default_cfg if pretrained: @@ -648,8 +346,8 @@ def gluon_seresnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **k """Constructs a SEResNeXt-101-64x4d model. """ default_cfg = default_cfgs['gluon_seresnext101_64x4d'] - model = GluonResNet( - BottleneckGl, [3, 4, 23, 3], cardinality=64, base_width=4, use_se=True, + model = ResNet( + Bottleneck, [3, 4, 23, 3], cardinality=64, base_width=4, use_se=True, num_classes=num_classes, in_chans=in_chans, **kwargs) model.default_cfg = default_cfg if pretrained: @@ -662,8 +360,8 @@ def gluon_senet154(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs an SENet-154 model. """ default_cfg = default_cfgs['gluon_senet154'] - model = GluonResNet( - BottleneckGl, [3, 8, 36, 3], cardinality=64, base_width=4, use_se=True, + model = ResNet( + Bottleneck, [3, 8, 36, 3], cardinality=64, base_width=4, use_se=True, deep_stem=True, down_kernel_size=3, block_reduce_first=2, num_classes=num_classes, in_chans=in_chans, **kwargs) model.default_cfg = default_cfg diff --git a/timm/models/helpers.py b/timm/models/helpers.py index 9ac728da..7460f4a2 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -57,7 +57,7 @@ def resume_checkpoint(model, checkpoint_path): raise FileNotFoundError() -def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None): +def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True): if cfg is None: cfg = getattr(model, 'default_cfg') if cfg is None or 'url' not in cfg or not cfg['url']: @@ -74,7 +74,6 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non elif in_chans != 3: assert False, "Invalid in_chans for pretrained weights" - strict = True classifier_name = cfg['classifier'] if num_classes == 1000 and cfg['num_classes'] == 1001: # special case for imagenet trained models with extra background class in pretrained weights diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py new file mode 100644 index 00000000..59ded4ab --- /dev/null +++ b/timm/models/hrnet.py @@ -0,0 +1,869 @@ +""" HRNet + +Copied from https://github.com/HRNet/HRNet-Image-Classification + +Original header: + Copyright (c) Microsoft + Licensed under the MIT License. 
+ Written by Bin Xiao (Bin.Xiao@microsoft.com) + Modified by Ke Sun (sunk@mail.ustc.edu.cn) +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import logging +import functools + +import numpy as np + +import torch +import torch.nn as nn +import torch._utils +import torch.nn.functional as F + +from .registry import register_model +from .helpers import load_pretrained +from .adaptive_avgmax_pool import SelectAdaptivePool2d +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD + +BN_MOMENTUM = 0.1 +logger = logging.getLogger(__name__) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv1', 'classifier': 'fc', + **kwargs + } + + +default_cfgs = { + 'hrnet_w18_small': _cfg(url=''), + 'hrnet_w18_small_v2': _cfg(url=''), + 'hrnet_w18': _cfg(url=''), + 'hrnet_w30': _cfg(url=''), + 'hrnet_w32': _cfg(url=''), + 'hrnet_w40': _cfg(url=''), + 'hrnet_w44': _cfg(url=''), + 'hrnet_w48': _cfg(url=''), +} + +cfg_cls_hrnet_w18_small = dict( + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(1,), + NUM_CHANNELS=(32,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2), + NUM_CHANNELS=(16, 32), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=1, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2, 2), + NUM_CHANNELS=(16, 32, 64), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=1, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2, 2, 2), + NUM_CHANNELS=(16, 32, 64, 128), + FUSE_METHOD='SUM', + ), +) + + +cfg_cls_hrnet_w18_small_v2 = dict( + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(2,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2), + NUM_CHANNELS=(18, 36), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=3, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2, 2), + NUM_CHANNELS=(18, 36, 72), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=2, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2, 2, 2), + NUM_CHANNELS=(18, 36, 72, 144), + FUSE_METHOD='SUM', + ), +) + +cfg_cls_hrnet_w18 = dict( + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(18, 36), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(18, 36, 72), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(18, 36, 72, 144), + FUSE_METHOD='SUM', + ), +) + + +cfg_cls_hrnet_w30 = dict( + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(30, 60), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(30, 60, 120), + FUSE_METHOD='SUM' + ), + 
STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(30, 60, 120, 240), + FUSE_METHOD='SUM', + ), +) + + +cfg_cls_hrnet_w32 = dict( + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(32, 64), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(32, 64, 128), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(32, 64, 128, 256), + FUSE_METHOD='SUM', + ), +) + +cfg_cls_hrnet_w40 = dict( + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(40, 80), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(40, 80, 160), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(40, 80, 160, 320), + FUSE_METHOD='SUM', + ), +) + + +cfg_cls_hrnet_w44 = dict( + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(44, 88), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(44, 88, 176), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(44, 88, 176, 352), + FUSE_METHOD='SUM', + ), +) + + +cfg_cls_hrnet_w48 = dict( + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(48, 96), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(48, 96, 192), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(48, 96, 192, 384), + FUSE_METHOD='SUM', + ), +) + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + 
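+    # Standard 1x1 -> 3x3 -> 1x1 bottleneck residual block (expansion 4); used for HRNet stage 1 and for the incre_modules of the classification head.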
expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv3 = nn.Conv2d( + planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d( + planes * self.expansion, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class HighResolutionModule(nn.Module): + def __init__(self, num_branches, blocks, num_blocks, num_inchannels, + num_channels, fuse_method, multi_scale_output=True): + super(HighResolutionModule, self).__init__() + self._check_branches( + num_branches, blocks, num_blocks, num_inchannels, num_channels) + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches( + num_branches, blocks, num_blocks, num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(False) + + def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels, num_channels): + if num_branches != len(num_blocks): + error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( + num_branches, len(num_blocks)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( + num_branches, len(num_channels)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_inchannels): + error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( + num_branches, len(num_inchannels)) + logger.error(error_msg) + raise ValueError(error_msg) + + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, + stride=1): + downsample = None + if stride != 1 or \ + self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.num_inchannels[branch_index], num_channels[branch_index] * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index], stride, downsample)) + self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index])) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + branches.append(self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + num_inchannels = 
self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False), + nn.BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM), + nn.Upsample(scale_factor=2 ** (j - i), mode='nearest'))) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i - j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False), + nn.BatchNorm2d(num_outchannels_conv3x3, momentum=BN_MOMENTUM))) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False), + nn.BatchNorm2d(num_outchannels_conv3x3, momentum=BN_MOMENTUM), + nn.ReLU(False))) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + + def forward(self, x): + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + for i in range(len(self.fuse_layers)): + y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + else: + y = y + self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + + return x_fuse + + +blocks_dict = { + 'BASIC': BasicBlock, + 'BOTTLENECK': Bottleneck +} + + +class HighResolutionNet(nn.Module): + + def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg'): + super(HighResolutionNet, self).__init__() + + self.conv1 = nn.Conv2d(in_chans, 64, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + + self.stage1_cfg = cfg['STAGE1'] + num_channels = self.stage1_cfg['NUM_CHANNELS'][0] + block = blocks_dict[self.stage1_cfg['BLOCK']] + num_blocks = self.stage1_cfg['NUM_BLOCKS'][0] + self.layer1 = self._make_layer(block, 64, num_channels, num_blocks) + stage1_out_channel = block.expansion * num_channels + + self.stage2_cfg = cfg['STAGE2'] + num_channels = self.stage2_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage2_cfg['BLOCK']] + num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition1 = self._make_transition_layer([stage1_out_channel], num_channels) + self.stage2, pre_stage_channels = self._make_stage(self.stage2_cfg, num_channels) + + self.stage3_cfg = cfg['STAGE3'] + num_channels = self.stage3_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage3_cfg['BLOCK']] + num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels) + self.stage3, pre_stage_channels = self._make_stage(self.stage3_cfg, num_channels) + + self.stage4_cfg = cfg['STAGE4'] + num_channels = self.stage4_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage4_cfg['BLOCK']] + num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels) + 
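+        # stage 4 is built with multi_scale_output=True so all four branch resolutions are kept and fed to the classification head below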
self.stage4, pre_stage_channels = self._make_stage(self.stage4_cfg, num_channels, multi_scale_output=True) + + # Classification Head + self.incre_modules, self.downsamp_modules, self.final_layer = self._make_head(pre_stage_channels) + + self.classifier = nn.Linear(2048, num_classes) + + self.init_weights() + + def _make_head(self, pre_stage_channels): + head_block = Bottleneck + head_channels = [32, 64, 128, 256] + + # Increasing the #channels on each resolution + # from C, 2C, 4C, 8C to 128, 256, 512, 1024 + incre_modules = [] + for i, channels in enumerate(pre_stage_channels): + incre_modules.append( + self._make_layer(head_block, channels, head_channels[i], 1, stride=1)) + incre_modules = nn.ModuleList(incre_modules) + + # downsampling modules + downsamp_modules = [] + for i in range(len(pre_stage_channels) - 1): + in_channels = head_channels[i] * head_block.expansion + out_channels = head_channels[i + 1] * head_block.expansion + downsamp_module = nn.Sequential( + nn.Conv2d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1), + nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM), + nn.ReLU(inplace=True) + ) + downsamp_modules.append(downsamp_module) + downsamp_modules = nn.ModuleList(downsamp_modules) + + final_layer = nn.Sequential( + nn.Conv2d( + in_channels=head_channels[3] * head_block.expansion, + out_channels=2048, kernel_size=1, stride=1, padding=0 + ), + nn.BatchNorm2d(2048, momentum=BN_MOMENTUM), + nn.ReLU(inplace=True) + ) + + return incre_modules, downsamp_modules, final_layer + + def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append(nn.Sequential( + nn.Conv2d(num_channels_pre_layer[i], num_channels_cur_layer[i], 3, 1, 1, bias=False), + nn.BatchNorm2d(num_channels_cur_layer[i], momentum=BN_MOMENTUM), + nn.ReLU(inplace=True))) + else: + transition_layers.append(None) + else: + conv3x3s = [] + for j in range(i + 1 - num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = num_channels_cur_layer[i] if j == i - num_branches_pre else inchannels + conv3x3s.append(nn.Sequential( + nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False), + nn.BatchNorm2d(outchannels, momentum=BN_MOMENTUM), + nn.ReLU(inplace=True))) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(inplanes, planes, stride, downsample)) + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True): + num_modules = layer_config['NUM_MODULES'] + num_branches = layer_config['NUM_BRANCHES'] + num_blocks = layer_config['NUM_BLOCKS'] + num_channels = layer_config['NUM_CHANNELS'] + block = blocks_dict[layer_config['BLOCK']] + fuse_method = layer_config['FUSE_METHOD'] + + modules = [] + for i in 
range(num_modules): + # multi_scale_output is only used last module + if not multi_scale_output and i == num_modules - 1: + reset_multi_scale_output = False + else: + reset_multi_scale_output = True + + modules.append(HighResolutionModule( + num_branches, block, num_blocks, num_inchannels, num_channels, fuse_method, reset_multi_scale_output) + ) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + def init_weights(self, pretrained='', ): + logger.info('=> init weights from normal distribution') + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_cfg['NUM_BRANCHES']): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_cfg['NUM_BRANCHES']): + if self.transition2[i] is not None: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_cfg['NUM_BRANCHES']): + if self.transition3[i] is not None: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage4(x_list) + + # Classification Head + y = self.incre_modules[0](y_list[0]) + for i in range(len(self.downsamp_modules)): + y = self.incre_modules[i + 1](y_list[i + 1]) + self.downsamp_modules[i](y) + + y = self.final_layer(y) + + if torch._C._get_tracing_state(): + y = y.flatten(start_dim=2).mean(dim=2) + else: + y = F.avg_pool2d(y, kernel_size=y.size()[2:]).view(y.size(0), -1) + + y = self.classifier(y) + + return y + + + +@register_model +def hrnet_w18_small(pretrained=True, **kwargs): + default_cfg = default_cfgs['hrnet_w18_small'] + model = HighResolutionNet(cfg_cls_hrnet_w18_small, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained( + model, + default_cfg, + num_classes=kwargs.get('num_classes', 0), + in_chans=kwargs.get('in_chans', 3)) + return model + + +@register_model +def hrnet_w18_small_v2(pretrained=True, **kwargs): + default_cfg = default_cfgs['hrnet_w18_small_v2'] + model = HighResolutionNet(cfg_cls_hrnet_w18_small_v2, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained( + model, + default_cfg, + num_classes=kwargs.get('num_classes', 0), + in_chans=kwargs.get('in_chans', 3)) + return model + +@register_model +def hrnet_w18(pretrained=True, **kwargs): + default_cfg = default_cfgs['hrnet_w18'] + model = HighResolutionNet(cfg_cls_hrnet_w18, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained( + model, + default_cfg, + num_classes=kwargs.get('num_classes', 0), + in_chans=kwargs.get('in_chans', 3)) + return model + + +@register_model +def hrnet_w30(pretrained=True, **kwargs): + default_cfg = default_cfgs['hrnet_w30'] + model = HighResolutionNet(cfg_cls_hrnet_w30, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained( + model, + default_cfg, + num_classes=kwargs.get('num_classes', 0), + in_chans=kwargs.get('in_chans', 3)) + return model + +@register_model +def hrnet_w32(pretrained=True, **kwargs): + default_cfg = 
default_cfgs['hrnet_w32'] + model = HighResolutionNet(cfg_cls_hrnet_w32, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained( + model, + default_cfg, + num_classes=kwargs.get('num_classes', 0), + in_chans=kwargs.get('in_chans', 3)) + return model + +@register_model +def hrnet_w40(pretrained=True, **kwargs): + default_cfg = default_cfgs['hrnet_w40'] + model = HighResolutionNet(cfg_cls_hrnet_w40, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained( + model, + default_cfg, + num_classes=kwargs.get('num_classes', 0), + in_chans=kwargs.get('in_chans', 3)) + return model + + +@register_model +def hrnet_w44(pretrained=True, **kwargs): + default_cfg = default_cfgs['hrnet_w44'] + model = HighResolutionNet(cfg_cls_hrnet_w44, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained( + model, + default_cfg, + num_classes=kwargs.get('num_classes', 0), + in_chans=kwargs.get('in_chans', 3)) + return model + + +@register_model +def hrnet_w48(pretrained=True, **kwargs): + default_cfg = default_cfgs['hrnet_w48'] + model = HighResolutionNet(cfg_cls_hrnet_w48, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained( + model, + default_cfg, + num_classes=kwargs.get('num_classes', 0), + in_chans=kwargs.get('in_chans', 3)) + return model diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py new file mode 100644 index 00000000..a89adea4 --- /dev/null +++ b/timm/models/mobilenetv3.py @@ -0,0 +1,469 @@ + +""" MobileNet V3 + +A PyTorch impl of MobileNet-V3, compatible with TF weights from official impl. + +Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244 + +Hacked together by Ross Wightman +""" +import torch.nn as nn +import torch.nn.functional as F + +from .efficientnet_builder import * +from .activations import HardSwish, hard_sigmoid +from .registry import register_model +from .helpers import load_pretrained +from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .conv2d_layers import select_conv2d +from .feature_hooks import FeatureHooks +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD + +__all__ = ['MobileNetV3'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv_stem', 'classifier': 'classifier', + **kwargs + } + + +default_cfgs = { + 'mobilenetv3_large_075': _cfg(url=''), + 'mobilenetv3_large_100': _cfg(url=''), + 'mobilenetv3_small_075': _cfg(url=''), + 'mobilenetv3_small_100': _cfg(url=''), + 'mobilenetv3_rw': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth', + interpolation='bicubic'), + 'tf_mobilenetv3_large_075': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_large_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_large_minimal_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth', + 
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_small_075': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_small_100': _cfg( + url= 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_small_minimal_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), +} + +_DEBUG = False + + +class MobileNetV3(nn.Module): + """ MobiletNet-V3 + + Based on my EfficientNet implementation and building blocks, this model utilizes the MobileNet-v3 specific + 'efficient head', where global pooling is done before the head convolution without a final batch-norm + layer before the classifier. + + Paper: https://arxiv.org/abs/1905.02244 + """ + + def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True, + channel_multiplier=1.0, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0., + se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, + global_pool='avg', weight_init='goog'): + super(MobileNetV3, self).__init__() + + self.num_classes = num_classes + self.num_features = num_features + self.drop_rate = drop_rate + self._in_chs = in_chans + + # Stem + stem_size = round_channels(stem_size, channel_multiplier) + self.conv_stem = select_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = norm_layer(stem_size, **norm_kwargs) + self.act1 = act_layer(inplace=True) + self._in_chs = stem_size + + # Middle stages (IR/ER/DS Blocks) + builder = EfficientNetBuilder( + channel_multiplier, 8, None, 32, pad_type, act_layer, se_kwargs, + norm_layer, norm_kwargs, drop_connect_rate, verbose=_DEBUG) + self.blocks = nn.Sequential(*builder(self._in_chs, block_args)) + self.feature_info = builder.features + self._in_chs = builder.in_chs + + # Head + Pooling + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.conv_head = select_conv2d(self._in_chs, self.num_features, 1, padding=pad_type, bias=head_bias) + self.act2 = act_layer(inplace=True) + + # Classifier + self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), self.num_classes) + + for m in self.modules(): + if weight_init == 'goog': + efficientnet_init_goog(m) + else: + efficientnet_init_default(m) + + def as_sequential(self): + layers = [self.conv_stem, self.bn1, self.act1] + layers.extend(self.blocks) + layers.extend([self.global_pool, self.conv_head, self.act2]) + layers.extend([nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier]) + return nn.Sequential(*layers) + + def get_classifier(self): + return self.classifier + + def reset_classifier(self, num_classes, global_pool='avg'): + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.num_classes = num_classes + del self.classifier + if num_classes: + self.classifier = nn.Linear( + self.num_features * self.global_pool.feat_mult(), num_classes) + else: + self.classifier = None + + def forward_features(self, x): + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + x = self.blocks(x) + x = self.global_pool(x) + x = self.conv_head(x) + x = self.act2(x) + return x + + def 
forward(self, x): + x = self.forward_features(x) + x = x.flatten(1) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + return self.classifier(x) + + +class MobileNetV3Features(nn.Module): + """ MobileNetV3 Feature Extractor + + A work-in-progress feature extraction module for MobileNet-V3 to use as a backbone for segmentation + and object detection models. + """ + + def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl', + in_chans=3, stem_size=16, channel_multiplier=1.0, output_stride=32, pad_type='', + act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0., se_kwargs=None, + norm_layer=nn.BatchNorm2d, norm_kwargs=None, weight_init='goog'): + super(MobileNetV3Features, self).__init__() + norm_kwargs = norm_kwargs or {} + + # TODO only create stages needed, currently all stages are created regardless of out_indices + num_stages = max(out_indices) + 1 + + self.out_indices = out_indices + self.drop_rate = drop_rate + self._in_chs = in_chans + + # Stem + stem_size = round_channels(stem_size, channel_multiplier) + self.conv_stem = select_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = norm_layer(stem_size, **norm_kwargs) + self.act1 = act_layer(inplace=True) + self._in_chs = stem_size + + # Middle stages (IR/ER/DS Blocks) + builder = EfficientNetBuilder( + channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs, + norm_layer, norm_kwargs, drop_connect_rate, feature_location=feature_location, verbose=_DEBUG) + self.blocks = nn.Sequential(*builder(self._in_chs, block_args)) + self.feature_info = builder.features # builder provides info about feature channels for each block + self._in_chs = builder.in_chs + + for m in self.modules(): + if weight_init == 'goog': + efficientnet_init_goog(m) + else: + efficientnet_init_default(m) + + if _DEBUG: + for k, v in self.feature_info.items(): + print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs'])) + + # Register feature extraction hooks with FeatureHooks helper + hook_type = 'forward_pre' if feature_location == 'pre_pwl' else 'forward' + hooks = [dict(name=self.feature_info[idx]['name'], type=hook_type) for idx in out_indices] + self.feature_hooks = FeatureHooks(hooks, self.named_modules()) + + def feature_channels(self, idx=None): + """ Feature Channel Shortcut + Returns feature channel count for each output index if idx == None. If idx is an integer, will + return feature channel count for that feature block index (independent of out_indices setting). 
+ """ + if isinstance(idx, int): + return self.feature_info[idx]['num_chs'] + return [self.feature_info[i]['num_chs'] for i in self.out_indices] + + def forward(self, x): + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + self.blocks(x) + return self.feature_hooks.get_output(x.device) + + +def _create_model(model_kwargs, default_cfg, pretrained=False): + if model_kwargs.pop('features_only', False): + load_strict = False + model_kwargs.pop('num_classes', 0) + model_kwargs.pop('num_features', 0) + model_kwargs.pop('head_conv', None) + model_class = MobileNetV3Features + else: + load_strict = True + model_class = MobileNetV3 + + model = model_class(**model_kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained( + model, + default_cfg, + num_classes=model_kwargs.get('num_classes', 0), + in_chans=model_kwargs.get('in_chans', 3), + strict=load_strict) + return model + + +def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MobileNet-V3 model. + + Ref impl: ? + Paper: https://arxiv.org/abs/1905.02244 + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16_nre_noskip'], # relu + # stage 1, 112x112 in + ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu + # stage 3, 28x28 in + ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish + # stage 5, 14x14in + ['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish + # stage 6, 7x7 in + ['cn_r1_k1_s1_c960'], # hard-swish + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + head_bias=False, + channel_multiplier=channel_multiplier, + norm_kwargs=resolve_bn_args(kwargs), + act_layer=HardSwish, + se_kwargs=dict(gate_fn=hard_sigmoid, reduce_mid=True, divisor=1), + **kwargs, + ) + model = _create_model(model_kwargs, default_cfgs[variant], pretrained) + return model + + +def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MobileNet-V3 model. + + Ref impl: ? + Paper: https://arxiv.org/abs/1905.02244 + + Args: + channel_multiplier: multiplier to number of channels per layer. 
+ """ + if 'small' in variant: + num_features = 1024 + if 'minimal' in variant: + act_layer = nn.ReLU + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s2_e1_c16'], + # stage 1, 56x56 in + ['ir_r1_k3_s2_e4.5_c24', 'ir_r1_k3_s1_e3.67_c24'], + # stage 2, 28x28 in + ['ir_r1_k3_s2_e4_c40', 'ir_r2_k3_s1_e6_c40'], + # stage 3, 14x14 in + ['ir_r2_k3_s1_e3_c48'], + # stage 4, 14x14in + ['ir_r3_k3_s2_e6_c96'], + # stage 6, 7x7 in + ['cn_r1_k1_s1_c576'], + ] + else: + act_layer = HardSwish + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s2_e1_c16_se0.25_nre'], # relu + # stage 1, 56x56 in + ['ir_r1_k3_s2_e4.5_c24_nre', 'ir_r1_k3_s1_e3.67_c24_nre'], # relu + # stage 2, 28x28 in + ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r2_k5_s1_e6_c40_se0.25'], # hard-swish + # stage 3, 14x14 in + ['ir_r2_k5_s1_e3_c48_se0.25'], # hard-swish + # stage 4, 14x14in + ['ir_r3_k5_s2_e6_c96_se0.25'], # hard-swish + # stage 6, 7x7 in + ['cn_r1_k1_s1_c576'], # hard-swish + ] + else: + num_features = 1280 + if 'minimal' in variant: + act_layer = nn.ReLU + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16'], + # stage 1, 112x112 in + ['ir_r1_k3_s2_e4_c24', 'ir_r1_k3_s1_e3_c24'], + # stage 2, 56x56 in + ['ir_r3_k3_s2_e3_c40'], + # stage 3, 28x28 in + ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112'], + # stage 5, 14x14in + ['ir_r3_k3_s2_e6_c160'], + # stage 6, 7x7 in + ['cn_r1_k1_s1_c960'], + ] + else: + act_layer = HardSwish + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16_nre'], # relu + # stage 1, 112x112 in + ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu + # stage 3, 28x28 in + ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish + # stage 5, 14x14in + ['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish + # stage 6, 7x7 in + ['cn_r1_k1_s1_c960'], # hard-swish + ] + + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + num_features=num_features, + stem_size=16, + channel_multiplier=channel_multiplier, + norm_kwargs=resolve_bn_args(kwargs), + act_layer=act_layer, + se_kwargs=dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisor=8), + **kwargs, + ) + model = _create_model(model_kwargs, default_cfgs[variant], pretrained) + return model + + +@register_model +def mobilenetv3_large_075(pretrained=False, **kwargs): + """ MobileNet V3 """ + model = _gen_mobilenet_v3('mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv3_large_100(pretrained=False, **kwargs): + """ MobileNet V3 """ + model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv3_small_075(pretrained=False, **kwargs): + """ MobileNet V3 """ + model = _gen_mobilenet_v3('mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv3_small_100(pretrained=False, **kwargs): + print(kwargs) + """ MobileNet V3 """ + model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv3_rw(pretrained=False, **kwargs): + """ MobileNet V3 """ + if pretrained: + # pretrained model trained with non-default BN epsilon + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + model = _gen_mobilenet_v3_rw('mobilenetv3_rw', 1.0, pretrained=pretrained, 
**kwargs) + return model + + +@register_model +def tf_mobilenetv3_large_075(pretrained=False, **kwargs): + """ MobileNet V3 """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mobilenetv3_large_100(pretrained=False, **kwargs): + """ MobileNet V3 """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs): + """ MobileNet V3 """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mobilenetv3_small_075(pretrained=False, **kwargs): + """ MobileNet V3 """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mobilenetv3_small_100(pretrained=False, **kwargs): + """ MobileNet V3 """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs): + """ MobileNet V3 """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs) + return model diff --git a/timm/models/resnet.py b/timm/models/resnet.py index bedd303d..c7d80dba 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -103,7 +103,7 @@ class SEModule(nn.Module): def __init__(self, channels, reduction_channels): super(SEModule, self).__init__() - #self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.avg_pool = nn.AdaptiveAvgPool2d(1) self.fc1 = nn.Conv2d( channels, reduction_channels, kernel_size=1, padding=0, bias=True) self.relu = nn.ReLU(inplace=True) @@ -111,8 +111,7 @@ class SEModule(nn.Module): reduction_channels, channels, kernel_size=1, padding=0, bias=True) def forward(self, x): - #x_se = self.avg_pool(x) - x_se = x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1) + x_se = self.avg_pool(x) x_se = self.fc1(x_se) x_se = self.relu(x_se) x_se = self.fc2(x_se) @@ -287,7 +286,8 @@ class ResNet(nn.Module): cardinality=1, base_width=64, stem_width=64, deep_stem=False, block_reduce_first=1, down_kernel_size=1, avg_down=False, dilated=False, norm_layer=nn.BatchNorm2d, drop_rate=0.0, global_pool='avg', - zero_init_last_bn=True, block_args=dict()): + zero_init_last_bn=True, block_args=None): + block_args = block_args or dict() self.num_classes = num_classes self.inplanes = stem_width * 2 if deep_stem else 64 self.cardinality = cardinality diff --git a/timm/models/senet.py b/timm/models/senet.py index 7ec1c453..0fbcfb86 100644 --- a/timm/models/senet.py +++ b/timm/models/senet.py @@ -68,7 +68,7 @@ class SEModule(nn.Module): def __init__(self, channels, reduction): super(SEModule, self).__init__() - #self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.avg_pool = nn.AdaptiveAvgPool2d(1) self.fc1 = nn.Conv2d( channels, channels // reduction, kernel_size=1, padding=0) self.relu = nn.ReLU(inplace=True) @@ -78,8 +78,7 @@ class 
SEModule(nn.Module): def forward(self, x): module_input = x - #x = self.avg_pool(x) - x = x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1) + x = self.avg_pool(x) x = self.fc1(x) x = self.relu(x) x = self.fc2(x)
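
A minimal usage sketch for the models this patch adds (mobilenetv3.py and hrnet.py). It assumes the create_model factory exposed by timm.models resolves the @register_model names registered above; the input sizes and num_classes override are illustrative only. The tf_mobilenetv3_* entries in default_cfgs point at released weight files, so pretrained=True should fetch the ported TF weights, while the hrnet_* entries still carry empty url='' fields and are therefore built from scratch here.

import torch
from timm.models import create_model

# Ported TF MobileNet V3 weights are listed in default_cfgs above.
model = create_model('tf_mobilenetv3_large_100', pretrained=True)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # expect torch.Size([1, 1000])

# HRNet default_cfgs have no weight URLs yet in this patch, so no pretrained load.
hrnet = create_model('hrnet_w18_small', pretrained=False, num_classes=10)
print(hrnet(torch.randn(2, 3, 224, 224)).shape)  # expect torch.Size([2, 10])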