Add MixNet (https://arxiv.org/abs/1907.09595) with pretrained weights converted from Tensorflow impl

* refactor 'same' convolution and add helper to use MixedConv2d when needed
* improve performance of 'same' padding for cases that can be handled statically
* add support for extra exp, pw, and dw kernel specs with grouping support to decoder/string defs for MixNet
* shuffle some args for a bit more consistency, a little less clutter overall in gen_efficientnet.py
pull/23/head
Ross Wightman 6 years ago
parent 7a92caa560
commit dfa9298b4e

@ -31,8 +31,9 @@ I've included a few of my favourite models, but this is not an exhaustive collec
* PNasNet & NASNet-A (from [Cadene](https://github.com/Cadene/pretrained-models.pytorch))
* DPN (from [me](https://github.com/rwightman/pytorch-dpn-pretrained), weights hosted by Cadene)
* DPN-68, DPN-68b, DPN-92, DPN-98, DPN-131, DPN-107
* Generic EfficientNet (from my standalone [GenMobileNet](https://github.com/rwightman/genmobilenet-pytorch)) - A generic model that implements many of the mobile optimized architecture search derived models that utilize similar DepthwiseSeparable and InvertedResidual blocks
* Generic EfficientNet (from my standalone [GenMobileNet](https://github.com/rwightman/genmobilenet-pytorch)) - A generic model that implements many of the efficient models that utilize similar DepthwiseSeparable and InvertedResidual blocks
* EfficientNet (B0-B5) (https://arxiv.org/abs/1905.11946) -- validated, compat with TF weights
* MixNet (https://arxiv.org/abs/1907.09595) -- validated, compat with TF weights
* MNASNet B1, A1 (Squeeze-Excite), and Small (https://arxiv.org/abs/1807.11626)
* MobileNet-V1 (https://arxiv.org/abs/1704.04861)
* MobileNet-V2 (https://arxiv.org/abs/1801.04381)
@ -40,7 +41,7 @@ I've included a few of my favourite models, but this is not an exhaustive collec
* ChamNet (https://arxiv.org/abs/1812.08934) -- specific arch details hard to find, currently an educated guess
* FBNet-C (https://arxiv.org/abs/1812.03443) -- TODO A/B variants
* Single-Path NAS (https://arxiv.org/abs/1904.02877) -- pixel1 variant
Use the `--model` arg to specify model for train, validation, inference scripts. Match the all lowercase
creation fn for the model you'd like.
@ -118,11 +119,17 @@ I've leveraged the training scripts in this repository to train a few of the mod
| gluon_resnext50_32x4d | 79.356 (20.644) | 94.424 (5.576) | 25.03 | bicubic | |
| gluon_resnet101_v1b | 79.304 (20.696) | 94.524 (5.476) | 44.55 | bicubic | |
| gluon_resnet50_v1d | 79.074 (20.926) | 94.476 (5.524) | 25.58 | bicubic | |
| tf_mixnet_l *tfp | 78.846 (21.154) | 94.212 (5.788) | 7.33 | bilinear | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet) |
| tf_mixnet_l | 78.770 (21.230) | 94.004 (5.996) | 7.33 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet) |
| gluon_resnet50_v1s | 78.712 (21.288) | 94.242 (5.758) | 25.68 | bicubic | |
| gluon_resnet50_v1c | 78.010 (21.990) | 93.988 (6.012) | 25.58 | bicubic | |
| gluon_resnet50_v1b | 77.578 (22.422) | 93.718 (6.282) | 25.56 | bicubic | |
| tf_mixnet_m *tfp | 77.072 (22.928) | 93.368 (6.632) | 5.01 | bilinear | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet) |
| tf_mixnet_m | 76.950 (23.050) | 93.156 (6.844) | 5.01 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet) |
| tf_efficientnet_b0 *tfp | 76.828 (23.172) | 93.226 (6.774) | 5.29 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) |
| tf_efficientnet_b0 | 76.528 (23.472) | 93.010 (6.990) | 5.29 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) |
| tf_mixnet_s *tfp | 75.800 (24.200) | 92.788 (7.212) | 4.13 | bilinear | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet) |
| tf_mixnet_s | 75.648 (24.352) | 92.636 (7.364) | 4.13 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet) |
| gluon_resnet34_v1b | 74.580 (25.420) | 91.988 (8.012) | 21.80 | bicubic | |
| gluon_resnet18_v1b | 70.830 (29.170) | 89.756 (10.244) | 11.69 | bicubic | |

@ -112,7 +112,8 @@ def create_loader(
if tf_preprocessing and use_prefetcher:
from timm.data.tf_preprocessing import TfPreprocessTransform
transform = TfPreprocessTransform(is_training=is_training, size=img_size)
transform = TfPreprocessTransform(
is_training=is_training, size=img_size, interpolation=interpolation)
else:
if is_training:
transform = transforms_imagenet_train(

@ -83,7 +83,7 @@ def _at_least_x_are_equal(a, b, x):
return tf.greater_equal(tf.reduce_sum(match), x)
def _decode_and_random_crop(image_bytes, image_size):
def _decode_and_random_crop(image_bytes, image_size, resize_method):
"""Make a random crop of image_size."""
bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
image = distorted_bounding_box_crop(
@ -100,13 +100,12 @@ def _decode_and_random_crop(image_bytes, image_size):
image = tf.cond(
bad,
lambda: _decode_and_center_crop(image_bytes, image_size),
lambda: tf.image.resize_bicubic([image], # pylint: disable=g-long-lambda
[image_size, image_size])[0])
lambda: tf.image.resize([image], [image_size, image_size], resize_method)[0])
return image
def _decode_and_center_crop(image_bytes, image_size):
def _decode_and_center_crop(image_bytes, image_size, resize_method):
"""Crops to center of image with padding then scales image_size."""
shape = tf.image.extract_jpeg_shape(image_bytes)
image_height = shape[0]
@ -122,7 +121,7 @@ def _decode_and_center_crop(image_bytes, image_size):
crop_window = tf.stack([offset_height, offset_width,
padded_center_crop_size, padded_center_crop_size])
image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
image = tf.image.resize_bicubic([image], [image_size, image_size])[0]
image = tf.image.resize([image], [image_size, image_size], resize_method)[0]
return image
@ -133,18 +132,20 @@ def _flip(image):
return image
def preprocess_for_train(image_bytes, use_bfloat16, image_size=IMAGE_SIZE):
def preprocess_for_train(image_bytes, use_bfloat16, image_size=IMAGE_SIZE, interpolation='bicubic'):
"""Preprocesses the given image for evaluation.
Args:
image_bytes: `Tensor` representing an image binary of arbitrary size.
use_bfloat16: `bool` for whether to use bfloat16.
image_size: image size.
interpolation: image interpolation method
Returns:
A preprocessed image `Tensor`.
"""
image = _decode_and_random_crop(image_bytes, image_size)
resize_method = tf.image.ResizeMethod.BICUBIC if interpolation == 'bicubic' else tf.image.ResizeMethod.BILINEAR
image = _decode_and_random_crop(image_bytes, image_size, resize_method)
image = _flip(image)
image = tf.reshape(image, [image_size, image_size, 3])
image = tf.image.convert_image_dtype(
@ -152,18 +153,20 @@ def preprocess_for_train(image_bytes, use_bfloat16, image_size=IMAGE_SIZE):
return image
def preprocess_for_eval(image_bytes, use_bfloat16, image_size=IMAGE_SIZE):
def preprocess_for_eval(image_bytes, use_bfloat16, image_size=IMAGE_SIZE, interpolation='bicubic'):
"""Preprocesses the given image for evaluation.
Args:
image_bytes: `Tensor` representing an image binary of arbitrary size.
use_bfloat16: `bool` for whether to use bfloat16.
image_size: image size.
interpolation: image interpolation method
Returns:
A preprocessed image `Tensor`.
"""
image = _decode_and_center_crop(image_bytes, image_size)
resize_method = tf.image.ResizeMethod.BICUBIC if interpolation == 'bicubic' else tf.image.ResizeMethod.BILINEAR
image = _decode_and_center_crop(image_bytes, image_size, resize_method)
image = tf.reshape(image, [image_size, image_size, 3])
image = tf.image.convert_image_dtype(
image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32)
@ -173,7 +176,8 @@ def preprocess_for_eval(image_bytes, use_bfloat16, image_size=IMAGE_SIZE):
def preprocess_image(image_bytes,
is_training=False,
use_bfloat16=False,
image_size=IMAGE_SIZE):
image_size=IMAGE_SIZE,
interpolation='bicubic'):
"""Preprocesses the given image.
Args:
@ -181,21 +185,23 @@ def preprocess_image(image_bytes,
is_training: `bool` for whether the preprocessing is for training.
use_bfloat16: `bool` for whether to use bfloat16.
image_size: image size.
interpolation: image interpolation method
Returns:
A preprocessed image `Tensor` with value range of [0, 255].
"""
if is_training:
return preprocess_for_train(image_bytes, use_bfloat16, image_size)
return preprocess_for_train(image_bytes, use_bfloat16, image_size, interpolation)
else:
return preprocess_for_eval(image_bytes, use_bfloat16, image_size)
return preprocess_for_eval(image_bytes, use_bfloat16, image_size, interpolation)
class TfPreprocessTransform:
def __init__(self, is_training=False, size=224):
def __init__(self, is_training=False, size=224, interpolation='bicubic'):
self.is_training = is_training
self.size = size[0] if isinstance(size, tuple) else size
self.interpolation = interpolation
self._image_bytes = None
self.process_image = self._build_tf_graph()
self.sess = None
@ -206,7 +212,8 @@ class TfPreprocessTransform:
shape=[],
dtype=tf.string,
)
img = preprocess_image(self._image_bytes, self.is_training, False, self.size)
img = preprocess_image(
self._image_bytes, self.is_training, False, self.size, self.interpolation)
return img
def __call__(self, image_bytes):

@ -0,0 +1,120 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def _is_static_pad(kernel_size, stride=1, dilation=1, **_):
return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
def _get_padding(kernel_size, stride=1, dilation=1, **_):
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
return padding
def _calc_same_pad(i, k, s, d):
return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
def _split_channels(num_chan, num_groups):
split = [num_chan // num_groups for _ in range(num_groups)]
split[0] += num_chan - sum(split)
return split
class Conv2dSame(nn.Conv2d):
""" Tensorflow like 'SAME' convolution wrapper for 2D convolutions
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True):
super(Conv2dSame, self).__init__(
in_channels, out_channels, kernel_size, stride, 0, dilation,
groups, bias)
def forward(self, x):
ih, iw = x.size()[-2:]
kh, kw = self.weight.size()[-2:]
pad_h = _calc_same_pad(ih, kh, self.stride[0], self.dilation[0])
pad_w = _calc_same_pad(iw, kw, self.stride[1], self.dilation[1])
if pad_h > 0 or pad_w > 0:
x = F.pad(x, [pad_w//2, pad_w - pad_w//2, pad_h//2, pad_h - pad_h//2])
return F.conv2d(x, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups)
def conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
padding = kwargs.pop('padding', '')
kwargs.setdefault('bias', False)
if isinstance(padding, str):
# for any string padding, the padding will be calculated for you, one of three ways
padding = padding.lower()
if padding == 'same':
# TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
if _is_static_pad(kernel_size, **kwargs):
# static case, no extra overhead
padding = _get_padding(kernel_size, **kwargs)
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
else:
# dynamic padding
return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
elif padding == 'valid':
# 'VALID' padding, same as padding=0
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=0, **kwargs)
else:
# Default to PyTorch style 'same'-ish symmetric padding
padding = _get_padding(kernel_size, **kwargs)
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
else:
# padding was specified as a number or pair
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
class MixedConv2d(nn.Module):
""" Mixed Grouped Convolution
Based on MDConv and GroupedConv in MixNet impl:
https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
"""
def __init__(self, in_channels, out_channels, kernel_size=3,
stride=1, padding='', dilated=False, depthwise=False, **kwargs):
super(MixedConv2d, self).__init__()
kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size]
num_groups = len(kernel_size)
in_splits = _split_channels(in_channels, num_groups)
out_splits = _split_channels(out_channels, num_groups)
for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
d = 1
# FIXME make compat with non-square kernel/dilations/strides
if stride == 1 and dilated:
d, k = (k - 1) // 2, 3
conv_groups = out_ch if depthwise else 1
# use add_module to keep key space clean
self.add_module(
str(idx),
conv2d_pad(
in_ch, out_ch, k, stride=stride,
padding=padding, dilation=d, groups=conv_groups, **kwargs)
)
self.splits = in_splits
def forward(self, x):
x_split = torch.split(x, self.splits, 1)
x_out = [c(x) for x, c in zip(x_split, self._modules.values())]
x = torch.cat(x_out, 1)
return x
# helper method
def select_conv2d(in_chs, out_chs, kernel_size, **kwargs):
assert 'groups' not in kwargs # only use 'depthwise' bool arg
if isinstance(kernel_size, list):
# We're going to use only lists for defining the MixedConv2d kernel groups,
# ints, tuples, other iterables will continue to pass to normal conv and specify h, w.
return MixedConv2d(in_chs, out_chs, kernel_size, **kwargs)
else:
depthwise = kwargs.pop('depthwise', False)
groups = out_chs if depthwise else 1
return conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs)

@ -1,39 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class Conv2dSame(nn.Conv2d):
""" Tensorflow like 'SAME' convolution wrapper for 2D convolutions
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True):
super(Conv2dSame, self).__init__(
in_channels, out_channels, kernel_size, stride, 0, dilation,
groups, bias)
def forward(self, x):
ih, iw = x.size()[-2:]
kh, kw = self.weight.size()[-2:]
oh = math.ceil(ih / self.stride[0])
ow = math.ceil(iw / self.stride[1])
pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
if pad_h > 0 or pad_w > 0:
x = F.pad(x, [pad_w//2, pad_w - pad_w//2, pad_h//2, pad_h - pad_h//2])
return F.conv2d(x, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups)
# helper method
def sconv2d(in_chs, out_chs, kernel_size, **kwargs):
padding = kwargs.pop('padding', 0)
if isinstance(padding, str):
if padding.lower() == 'same':
return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
else:
# 'valid'
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=0, **kwargs)
else:
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)

@ -1,13 +1,13 @@
""" Generic EfficientNets
A generic class with building blocks to support a variety of models with efficient architectures:
* EfficientNet (B0-B4 in code right now, work in progress, still verifying)
* MNasNet B1, A1 (SE), Small
* MobileNet V1, V2, and V3 (work in progress)
* EfficientNet (B0-B5)
* MixNet (Small, Medium, and Large)
* MnasNet B1, A1 (SE), Small
* MobileNet V1, V2, and V3
* FBNet-C (TODO A & B)
* ChamNet (TODO still guessing at architecture definition)
* Single-Path NAS Pixel1
* ShuffleNetV2 (TODO add IR shuffle block)
* And likely more...
TODO not all combinations and variations have been tested. Currently working on training hyper-params...
@ -27,7 +27,7 @@ import torch.nn.functional as F
from .registry import register_model
from .helpers import load_pretrained
from .adaptive_avgmax_pool import SelectAdaptivePool2d
from .conv2d_same import sconv2d
from .conv2d_helpers import select_conv2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@ -37,7 +37,7 @@ __all__ = ['GenEfficientNet']
def _cfg(url='', **kwargs):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bilinear',
'crop_pct': 0.875, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'conv_stem', 'classifier': 'classifier',
**kwargs
@ -48,14 +48,12 @@ default_cfgs = {
'mnasnet_050': _cfg(url=''),
'mnasnet_075': _cfg(url=''),
'mnasnet_100': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth',
interpolation='bicubic'),
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth'),
'mnasnet_140': _cfg(url=''),
'semnasnet_050': _cfg(url=''),
'semnasnet_075': _cfg(url=''),
'semnasnet_100': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth',
interpolation='bicubic'),
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth'),
'semnasnet_140': _cfg(url=''),
'mnasnet_small': _cfg(url=''),
'mobilenetv1_100': _cfg(url=''),
@ -63,23 +61,23 @@ default_cfgs = {
'mobilenetv3_050': _cfg(url=''),
'mobilenetv3_075': _cfg(url=''),
'mobilenetv3_100': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth',
interpolation='bicubic'),
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth'),
'chamnetv1_100': _cfg(url=''),
'chamnetv2_100': _cfg(url=''),
'fbnetc_100': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth'),
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth',
interpolation='bilinear'),
'spnasnet_100': _cfg(
url='https://www.dropbox.com/s/iieopt18rytkgaa/spnasnet_100-048bc3f4.pth?dl=1'),
url='https://www.dropbox.com/s/iieopt18rytkgaa/spnasnet_100-048bc3f4.pth?dl=1',
interpolation='bilinear'),
'efficientnet_b0': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0-d6904d92.pth',
interpolation='bicubic'),
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0-d6904d92.pth'),
'efficientnet_b1': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth',
input_size=(3, 240, 240), pool_size=(8, 8), interpolation='bicubic', crop_pct=0.882),
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
'efficientnet_b2': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2-cf78dc4d.pth',
input_size=(3, 260, 260), pool_size=(9, 9), interpolation='bicubic', crop_pct=0.890),
input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
'efficientnet_b3': _cfg(
url='', input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
'efficientnet_b4': _cfg(
@ -88,22 +86,31 @@ default_cfgs = {
url='', input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
'tf_efficientnet_b0': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0-0af12548.pth',
input_size=(3, 224, 224), interpolation='bicubic'),
input_size=(3, 224, 224)),
'tf_efficientnet_b1': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1-5c1377c4.pth',
input_size=(3, 240, 240), pool_size=(8, 8), interpolation='bicubic', crop_pct=0.882),
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
'tf_efficientnet_b2': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2-e393ef04.pth',
input_size=(3, 260, 260), pool_size=(9, 9), interpolation='bicubic', crop_pct=0.890),
input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
'tf_efficientnet_b3': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3-e3bd6955.pth',
input_size=(3, 300, 300), pool_size=(10, 10), interpolation='bicubic', crop_pct=0.904),
input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
'tf_efficientnet_b4': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4-74ee3bed.pth',
input_size=(3, 380, 380), pool_size=(12, 12), interpolation='bicubic', crop_pct=0.922),
input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
'tf_efficientnet_b5': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5-c6949ce9.pth',
input_size=(3, 456, 456), pool_size=(15, 15), interpolation='bicubic', crop_pct=0.934)
input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
'mixnet_s': _cfg(url=''),
'mixnet_m': _cfg(url=''),
'mixnet_l': _cfg(url=''),
'tf_mixnet_s': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth'),
'tf_mixnet_m': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_m-0f4d8805.pth'),
'tf_mixnet_l': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_l-6c92e0c8.pth'),
}
@ -151,6 +158,13 @@ def _round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
return new_channels
def _parse_ksize(ss):
if ss.isdigit():
return int(ss)
else:
return [int(k) for k in ss.split('.')]
def _decode_block_str(block_str, depth_multiplier=1.0):
""" Decode block definition string
@ -168,7 +182,7 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
e - expansion ratio,
c - output channels,
se - squeeze/excitation ratio
a - activation fn ('re', 'r6', or 'hs')
n - activation fn ('re', 'r6', 'hs', or 'sw')
Args:
block_str: a string representation of block arguments.
Returns:
@ -184,7 +198,9 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
noskip = False
for op in ops:
# string options being checked on individual basis, combine if they grow
if op.startswith('a'):
if op == 'noskip':
noskip = True
elif op.startswith('n'):
# activation fn
key = op[0]
v = op[1:]
@ -194,11 +210,11 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
value = F.relu6
elif v == 'hs':
value = hard_swish
elif v == 'sw':
value = swish
else:
continue
options[key] = value
elif op == 'noskip':
noskip = True
else:
# all numeric options
splits = re.split(r'(\d.*)', op)
@ -207,14 +223,18 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
options[key] = value
# if act_fn is None, the model default (passed to model init) will be used
act_fn = options['a'] if 'a' in options else None
act_fn = options['n'] if 'n' in options else None
exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
num_repeat = int(options['r'])
# each type of block has different valid arguments, fill accordingly
if block_type == 'ir':
block_args = dict(
block_type=block_type,
kernel_size=int(options['k']),
dw_kernel_size=_parse_ksize(options['k']),
exp_kernel_size=exp_kernel_size,
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
exp_ratio=float(options['e']),
se_ratio=float(options['se']) if 'se' in options else None,
@ -222,20 +242,17 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
act_fn=act_fn,
noskip=noskip,
)
if 'g' in options:
block_args['pw_group'] = options['g']
if options['g'] > 1:
block_args['shuffle_type'] = 'mid'
elif block_type == 'ds' or block_type == 'dsa':
block_args = dict(
block_type=block_type,
kernel_size=int(options['k']),
dw_kernel_size=_parse_ksize(options['k']),
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
se_ratio=float(options['se']) if 'se' in options else None,
stride=int(options['s']),
act_fn=act_fn,
noskip=block_type == 'dsa' or noskip,
pw_act=block_type == 'dsa',
noskip=block_type == 'dsa' or noskip,
)
elif block_type == 'cn':
block_args = dict(
@ -254,15 +271,6 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
return [deepcopy(block_args) for _ in range(num_repeat)]
def _get_padding(kernel_size, stride, dilation=1):
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
return padding
def _padding_arg(default, padding_same=False):
return 'SAME' if padding_same else default
def _decode_arch_args(string_list):
block_args = []
for block_str in string_list:
@ -316,20 +324,18 @@ class _BlockBuilder:
https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py
"""
def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
drop_connect_rate=0., act_fn=None, se_gate_fn=sigmoid, se_reduce_mid=False,
bn_args=_BN_ARGS_PT, padding_same=False,
verbose=False):
pad_type='', act_fn=None, se_gate_fn=sigmoid, se_reduce_mid=False,
bn_args=_BN_ARGS_PT, drop_connect_rate=0., verbose=False):
self.channel_multiplier = channel_multiplier
self.channel_divisor = channel_divisor
self.channel_min = channel_min
self.drop_connect_rate = drop_connect_rate
self.pad_type = pad_type
self.act_fn = act_fn
self.se_gate_fn = se_gate_fn
self.se_reduce_mid = se_reduce_mid
self.bn_args = bn_args
self.padding_same = padding_same
self.drop_connect_rate = drop_connect_rate
self.verbose = verbose
# updated during build
@ -345,7 +351,7 @@ class _BlockBuilder:
ba['in_chs'] = self.in_chs
ba['out_chs'] = self._round_channels(ba['out_chs'])
ba['bn_args'] = self.bn_args
ba['padding_same'] = self.padding_same
ba['pad_type'] = self.pad_type
# block act fn overrides the model default
ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn
assert ba['act_fn'] is not None
@ -493,16 +499,11 @@ class SqueezeExcite(nn.Module):
class ConvBnAct(nn.Module):
def __init__(self, in_chs, out_chs, kernel_size,
stride=1, act_fn=F.relu,
bn_args=_BN_ARGS_PT, padding_same=False):
stride=1, pad_type='', act_fn=F.relu, bn_args=_BN_ARGS_PT):
super(ConvBnAct, self).__init__()
assert stride in [1, 2]
self.act_fn = act_fn
padding = _padding_arg(_get_padding(kernel_size, stride), padding_same)
self.conv = sconv2d(
in_chs, out_chs, kernel_size,
stride=stride, padding=padding, bias=False)
self.conv = select_conv2d(in_chs, out_chs, kernel_size, stride=stride, padding=pad_type)
self.bn1 = nn.BatchNorm2d(out_chs, **bn_args)
def forward(self, x):
@ -517,10 +518,11 @@ class DepthwiseSeparableConv(nn.Module):
Used for DS convs in MobileNet-V1 and in the place of IR blocks with an expansion
factor of 1.0. This is an alternative to having a IR with optional first pw conv.
"""
def __init__(self, in_chs, out_chs, kernel_size,
stride=1, act_fn=F.relu, noskip=False, pw_act=False,
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
stride=1, pad_type='', act_fn=F.relu, noskip=False,
pw_kernel_size=1, pw_act=False,
se_ratio=0., se_gate_fn=sigmoid,
bn_args=_BN_ARGS_PT, padding_same=False, drop_connect_rate=0.):
bn_args=_BN_ARGS_PT, drop_connect_rate=0.):
super(DepthwiseSeparableConv, self).__init__()
assert stride in [1, 2]
self.has_se = se_ratio is not None and se_ratio > 0.
@ -528,12 +530,9 @@ class DepthwiseSeparableConv(nn.Module):
self.has_pw_act = pw_act # activation after point-wise conv
self.act_fn = act_fn
self.drop_connect_rate = drop_connect_rate
dw_padding = _padding_arg(kernel_size // 2, padding_same)
pw_padding = _padding_arg(0, padding_same)
self.conv_dw = sconv2d(
in_chs, in_chs, kernel_size,
stride=stride, padding=dw_padding, groups=in_chs, bias=False)
self.conv_dw = select_conv2d(
in_chs, in_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True)
self.bn1 = nn.BatchNorm2d(in_chs, **bn_args)
# Squeeze-and-excitation
@ -541,7 +540,7 @@ class DepthwiseSeparableConv(nn.Module):
self.se = SqueezeExcite(
in_chs, reduce_chs=max(1, int(in_chs * se_ratio)), act_fn=act_fn, gate_fn=se_gate_fn)
self.conv_pw = sconv2d(in_chs, out_chs, 1, padding=pw_padding, bias=False)
self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
self.bn2 = nn.BatchNorm2d(out_chs, **bn_args)
def forward(self, x):
@ -569,31 +568,29 @@ class DepthwiseSeparableConv(nn.Module):
class InvertedResidual(nn.Module):
""" Inverted residual block w/ optional SE"""
def __init__(self, in_chs, out_chs, kernel_size,
stride=1, act_fn=F.relu, exp_ratio=1.0, noskip=False,
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
stride=1, pad_type='', act_fn=F.relu, noskip=False,
exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
se_ratio=0., se_reduce_mid=False, se_gate_fn=sigmoid,
shuffle_type=None, pw_group=1,
bn_args=_BN_ARGS_PT, padding_same=False, drop_connect_rate=0.):
shuffle_type=None, bn_args=_BN_ARGS_PT, drop_connect_rate=0.):
super(InvertedResidual, self).__init__()
mid_chs = int(in_chs * exp_ratio)
self.has_se = se_ratio is not None and se_ratio > 0.
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
self.act_fn = act_fn
self.drop_connect_rate = drop_connect_rate
dw_padding = _padding_arg(kernel_size // 2, padding_same)
pw_padding = _padding_arg(0, padding_same)
# Point-wise expansion
self.conv_pw = sconv2d(in_chs, mid_chs, 1, padding=pw_padding, groups=pw_group, bias=False)
self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
self.bn1 = nn.BatchNorm2d(mid_chs, **bn_args)
self.shuffle_type = shuffle_type
if shuffle_type is not None:
self.shuffle = ChannelShuffle(pw_group)
if shuffle_type is not None and isinstance(exp_kernel_size, list):
self.shuffle = ChannelShuffle(len(exp_kernel_size))
# Depth-wise convolution
self.conv_dw = sconv2d(
mid_chs, mid_chs, kernel_size, padding=dw_padding, stride=stride, groups=mid_chs, bias=False)
self.conv_dw = select_conv2d(
mid_chs, mid_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True)
self.bn2 = nn.BatchNorm2d(mid_chs, **bn_args)
# Squeeze-and-excitation
@ -603,7 +600,7 @@ class InvertedResidual(nn.Module):
mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)), act_fn=act_fn, gate_fn=se_gate_fn)
# Point-wise linear projection
self.conv_pwl = sconv2d(mid_chs, out_chs, 1, padding=pw_padding, groups=pw_group, bias=False)
self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type)
self.bn3 = nn.BatchNorm2d(out_chs, **bn_args)
def forward(self, x):
@ -649,18 +646,19 @@ class GenEfficientNet(nn.Module):
* MobileNet-V1
* MobileNet-V2
* MobileNet-V3
* MNASNet A1, B1, and small
* MnasNet A1, B1, and small
* FBNet A, B, and C
* ChamNet (arch details are murky)
* Single-Path NAS Pixel1
* EfficientNetB0-B4 (rest easy to add)
* EfficientNet B0-B5
* MixNet S, M, L
"""
def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280,
channel_multiplier=1.0, channel_divisor=8, channel_min=None,
drop_rate=0., drop_connect_rate=0., act_fn=F.relu,
pad_type='', act_fn=F.relu, drop_rate=0., drop_connect_rate=0.,
se_gate_fn=sigmoid, se_reduce_mid=False, bn_args=_BN_ARGS_PT,
global_pool='avg', head_conv='default', weight_init='goog', padding_same=False):
global_pool='avg', head_conv='default', weight_init='goog'):
super(GenEfficientNet, self).__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
@ -668,16 +666,14 @@ class GenEfficientNet(nn.Module):
self.num_features = num_features
stem_size = _round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
self.conv_stem = sconv2d(
in_chans, stem_size, 3,
padding=_padding_arg(1, padding_same), stride=2, bias=False)
self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = nn.BatchNorm2d(stem_size, **bn_args)
in_chs = stem_size
builder = _BlockBuilder(
channel_multiplier, channel_divisor, channel_min,
drop_connect_rate, act_fn, se_gate_fn, se_reduce_mid,
bn_args, padding_same, verbose=_DEBUG)
pad_type, act_fn, se_gate_fn, se_reduce_mid,
bn_args, drop_connect_rate, verbose=_DEBUG)
self.blocks = nn.Sequential(*builder(in_chs, block_args))
in_chs = builder.in_chs
@ -687,9 +683,7 @@ class GenEfficientNet(nn.Module):
assert in_chs == self.num_features
else:
self.efficient_head = head_conv == 'efficient'
self.conv_head = sconv2d(
in_chs, self.num_features, 1,
padding=_padding_arg(0, padding_same), bias=False)
self.conv_head = select_conv2d(in_chs, self.num_features, 1, padding=pad_type)
self.bn2 = None if self.efficient_head else nn.BatchNorm2d(self.num_features, **bn_args)
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
@ -919,11 +913,11 @@ def _gen_mobilenet_v3(channel_multiplier, num_classes=1000, **kwargs):
"""
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s1_e1_c16_are_noskip'], # relu
['ds_r1_k3_s1_e1_c16_nre_noskip'], # relu
# stage 1, 112x112 in
['ir_r1_k3_s2_e4_c24_are', 'ir_r1_k3_s1_e3_c24_are'], # relu
['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu
# stage 2, 56x56 in
['ir_r3_k5_s2_e3_c40_se0.25_are'], # relu
['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu
# stage 3, 28x28 in
['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish
# stage 4, 14x14in
@ -1129,6 +1123,78 @@ def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=
return model
def _gen_mixnet_s(channel_multiplier=1.0, num_classes=1000, **kwargs):
"""Creates a MixNet Small model.
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet
Paper: https://arxiv.org/abs/1907.09595
"""
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s1_e1_c16'], # relu
# stage 1, 112x112 in
['ir_r1_k3_a1.1_p1.1_s2_e6_c24', 'ir_r1_k3_a1.1_p1.1_s1_e3_c24'], # relu
# stage 2, 56x56 in
['ir_r1_k3.5.7_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish
# stage 3, 28x28 in
['ir_r1_k3.5.7_p1.1_s2_e6_c80_se0.25_nsw', 'ir_r2_k3.5_p1.1_s1_e6_c80_se0.25_nsw'], # swish
# stage 4, 14x14in
['ir_r1_k3.5.7_a1.1_p1.1_s1_e6_c120_se0.5_nsw', 'ir_r2_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish
# stage 5, 14x14in
['ir_r1_k3.5.7.9.11_s2_e6_c200_se0.5_nsw', 'ir_r2_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish
# 7x7
]
model = GenEfficientNet(
_decode_arch_def(arch_def),
num_classes=num_classes,
stem_size=16,
num_features=1536,
channel_multiplier=channel_multiplier,
channel_divisor=8,
channel_min=None,
bn_args=_resolve_bn_args(kwargs),
act_fn=F.relu,
**kwargs
)
return model
def _gen_mixnet_m(channel_multiplier=1.0, num_classes=1000, **kwargs):
"""Creates a MixNet Medium-Large model.
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet
Paper: https://arxiv.org/abs/1907.09595
"""
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s1_e1_c24'], # relu
# stage 1, 112x112 in
['ir_r1_k3.5.7_a1.1_p1.1_s2_e6_c32', 'ir_r1_k3_a1.1_p1.1_s1_e3_c32'], # relu
# stage 2, 56x56 in
['ir_r1_k3.5.7.9_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish
# stage 3, 28x28 in
['ir_r1_k3.5.7_s2_e6_c80_se0.25_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e6_c80_se0.25_nsw'], # swish
# stage 4, 14x14in
['ir_r1_k3_s1_e6_c120_se0.5_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish
# stage 5, 14x14in
['ir_r1_k3.5.7.9_s2_e6_c200_se0.5_nsw', 'ir_r3_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish
# 7x7
]
model = GenEfficientNet(
_decode_arch_def(arch_def),
num_classes=num_classes,
stem_size=24,
num_features=1536,
channel_multiplier=channel_multiplier,
channel_divisor=8,
channel_min=None,
bn_args=_resolve_bn_args(kwargs),
act_fn=F.relu,
**kwargs
)
return model
@register_model
def mnasnet_050(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
""" MNASNet B1, depth multiplier of 0.5. """
@ -1440,7 +1506,7 @@ def tf_efficientnet_b0(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
""" EfficientNet-B0. Tensorflow compatible variant """
default_cfg = default_cfgs['tf_efficientnet_b0']
kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
kwargs['padding_same'] = True
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
channel_multiplier=1.0, depth_multiplier=1.0,
num_classes=num_classes, in_chans=in_chans, **kwargs)
@ -1455,7 +1521,7 @@ def tf_efficientnet_b1(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
""" EfficientNet-B1. Tensorflow compatible variant """
default_cfg = default_cfgs['tf_efficientnet_b1']
kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
kwargs['padding_same'] = True
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
channel_multiplier=1.0, depth_multiplier=1.1,
num_classes=num_classes, in_chans=in_chans, **kwargs)
@ -1470,7 +1536,7 @@ def tf_efficientnet_b2(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
""" EfficientNet-B2. Tensorflow compatible variant """
default_cfg = default_cfgs['tf_efficientnet_b2']
kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
kwargs['padding_same'] = True
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
channel_multiplier=1.1, depth_multiplier=1.2,
num_classes=num_classes, in_chans=in_chans, **kwargs)
@ -1485,7 +1551,7 @@ def tf_efficientnet_b3(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
""" EfficientNet-B3. Tensorflow compatible variant """
default_cfg = default_cfgs['tf_efficientnet_b3']
kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
kwargs['padding_same'] = True
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
channel_multiplier=1.2, depth_multiplier=1.4,
num_classes=num_classes, in_chans=in_chans, **kwargs)
@ -1500,7 +1566,7 @@ def tf_efficientnet_b4(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
""" EfficientNet-B4. Tensorflow compatible variant """
default_cfg = default_cfgs['tf_efficientnet_b4']
kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
kwargs['padding_same'] = True
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
channel_multiplier=1.4, depth_multiplier=1.8,
num_classes=num_classes, in_chans=in_chans, **kwargs)
@ -1515,7 +1581,7 @@ def tf_efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
""" EfficientNet-B5. Tensorflow compatible variant """
default_cfg = default_cfgs['tf_efficientnet_b5']
kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
kwargs['padding_same'] = True
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
channel_multiplier=1.6, depth_multiplier=2.2,
num_classes=num_classes, in_chans=in_chans, **kwargs)
@ -1525,5 +1591,89 @@ def tf_efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
return model
@register_model
def mixnet_s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Creates a MixNet Small model.
"""
default_cfg = default_cfgs['mixnet_m']
model = _gen_mixnet_s(
channel_multiplier=1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
#if pretrained:
# load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def mixnet_m(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Creates a MixNet Medium model.
"""
default_cfg = default_cfgs['mixnet_m']
model = _gen_mixnet_m(
channel_multiplier=1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
#if pretrained:
# load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def mixnet_l(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Creates a MixNet Large model.
"""
default_cfg = default_cfgs['mixnet_l']
model = _gen_mixnet_m(
channel_multiplier=1.3, num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def tf_mixnet_s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Creates a MixNet Small model. Tensorflow compatible variant
"""
default_cfg = default_cfgs['tf_mixnet_s']
kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_mixnet_s(
channel_multiplier=1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def tf_mixnet_m(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Creates a MixNet Medium model. Tensorflow compatible variant
"""
default_cfg = default_cfgs['tf_mixnet_m']
kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_mixnet_m(
channel_multiplier=1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def tf_mixnet_l(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Creates a MixNet Large model. Tensorflow compatible variant
"""
default_cfg = default_cfgs['tf_mixnet_l']
kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_mixnet_m(
channel_multiplier=1.3, num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
def gen_efficientnet_model_names():
return set(_models)

Loading…
Cancel
Save