diff --git a/README.md b/README.md index ff534700..9044f875 100644 --- a/README.md +++ b/README.md @@ -31,8 +31,9 @@ I've included a few of my favourite models, but this is not an exhaustive collec * PNasNet & NASNet-A (from [Cadene](https://github.com/Cadene/pretrained-models.pytorch)) * DPN (from [me](https://github.com/rwightman/pytorch-dpn-pretrained), weights hosted by Cadene) * DPN-68, DPN-68b, DPN-92, DPN-98, DPN-131, DPN-107 -* Generic EfficientNet (from my standalone [GenMobileNet](https://github.com/rwightman/genmobilenet-pytorch)) - A generic model that implements many of the mobile optimized architecture search derived models that utilize similar DepthwiseSeparable and InvertedResidual blocks +* Generic EfficientNet (from my standalone [GenMobileNet](https://github.com/rwightman/genmobilenet-pytorch)) - A generic model that implements many of the efficient models that utilize similar DepthwiseSeparable and InvertedResidual blocks * EfficientNet (B0-B5) (https://arxiv.org/abs/1905.11946) -- validated, compat with TF weights + * MixNet (https://arxiv.org/abs/1907.09595) -- validated, compat with TF weights * MNASNet B1, A1 (Squeeze-Excite), and Small (https://arxiv.org/abs/1807.11626) * MobileNet-V1 (https://arxiv.org/abs/1704.04861) * MobileNet-V2 (https://arxiv.org/abs/1801.04381) @@ -40,7 +41,7 @@ I've included a few of my favourite models, but this is not an exhaustive collec * ChamNet (https://arxiv.org/abs/1812.08934) -- specific arch details hard to find, currently an educated guess * FBNet-C (https://arxiv.org/abs/1812.03443) -- TODO A/B variants * Single-Path NAS (https://arxiv.org/abs/1904.02877) -- pixel1 variant - + Use the `--model` arg to specify model for train, validation, inference scripts. Match the all lowercase creation fn for the model you'd like. @@ -118,11 +119,17 @@ I've leveraged the training scripts in this repository to train a few of the mod | gluon_resnext50_32x4d | 79.356 (20.644) | 94.424 (5.576) | 25.03 | bicubic | | | gluon_resnet101_v1b | 79.304 (20.696) | 94.524 (5.476) | 44.55 | bicubic | | | gluon_resnet50_v1d | 79.074 (20.926) | 94.476 (5.524) | 25.58 | bicubic | | +| tf_mixnet_l *tfp | 78.846 (21.154) | 94.212 (5.788) | 7.33 | bilinear | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet) | +| tf_mixnet_l | 78.770 (21.230) | 94.004 (5.996) | 7.33 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet) | | gluon_resnet50_v1s | 78.712 (21.288) | 94.242 (5.758) | 25.68 | bicubic | | | gluon_resnet50_v1c | 78.010 (21.990) | 93.988 (6.012) | 25.58 | bicubic | | | gluon_resnet50_v1b | 77.578 (22.422) | 93.718 (6.282) | 25.56 | bicubic | | +| tf_mixnet_m *tfp | 77.072 (22.928) | 93.368 (6.632) | 5.01 | bilinear | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet) | +| tf_mixnet_m | 76.950 (23.050) | 93.156 (6.844) | 5.01 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet) | | tf_efficientnet_b0 *tfp | 76.828 (23.172) | 93.226 (6.774) | 5.29 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) | | tf_efficientnet_b0 | 76.528 (23.472) | 93.010 (6.990) | 5.29 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) | +| tf_mixnet_s *tfp | 75.800 (24.200) | 92.788 (7.212) | 4.13 | bilinear | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet) | +| tf_mixnet_s | 75.648 (24.352) | 92.636 (7.364) | 4.13 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet) | | gluon_resnet34_v1b | 74.580 (25.420) | 91.988 (8.012) | 21.80 | bicubic | | | gluon_resnet18_v1b | 70.830 (29.170) | 89.756 (10.244) | 11.69 | bicubic | | diff --git a/timm/data/loader.py b/timm/data/loader.py index 6b6e2b39..1198d5e5 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -112,7 +112,8 @@ def create_loader( if tf_preprocessing and use_prefetcher: from timm.data.tf_preprocessing import TfPreprocessTransform - transform = TfPreprocessTransform(is_training=is_training, size=img_size) + transform = TfPreprocessTransform( + is_training=is_training, size=img_size, interpolation=interpolation) else: if is_training: transform = transforms_imagenet_train( diff --git a/timm/data/tf_preprocessing.py b/timm/data/tf_preprocessing.py index 9d0963d8..61dc78e3 100644 --- a/timm/data/tf_preprocessing.py +++ b/timm/data/tf_preprocessing.py @@ -83,7 +83,7 @@ def _at_least_x_are_equal(a, b, x): return tf.greater_equal(tf.reduce_sum(match), x) -def _decode_and_random_crop(image_bytes, image_size): +def _decode_and_random_crop(image_bytes, image_size, resize_method): """Make a random crop of image_size.""" bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4]) image = distorted_bounding_box_crop( @@ -100,13 +100,12 @@ def _decode_and_random_crop(image_bytes, image_size): image = tf.cond( bad, lambda: _decode_and_center_crop(image_bytes, image_size), - lambda: tf.image.resize_bicubic([image], # pylint: disable=g-long-lambda - [image_size, image_size])[0]) + lambda: tf.image.resize([image], [image_size, image_size], resize_method)[0]) return image -def _decode_and_center_crop(image_bytes, image_size): +def _decode_and_center_crop(image_bytes, image_size, resize_method): """Crops to center of image with padding then scales image_size.""" shape = tf.image.extract_jpeg_shape(image_bytes) image_height = shape[0] @@ -122,7 +121,7 @@ def _decode_and_center_crop(image_bytes, image_size): crop_window = tf.stack([offset_height, offset_width, padded_center_crop_size, padded_center_crop_size]) image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) - image = tf.image.resize_bicubic([image], [image_size, image_size])[0] + image = tf.image.resize([image], [image_size, image_size], resize_method)[0] return image @@ -133,18 +132,20 @@ def _flip(image): return image -def preprocess_for_train(image_bytes, use_bfloat16, image_size=IMAGE_SIZE): +def preprocess_for_train(image_bytes, use_bfloat16, image_size=IMAGE_SIZE, interpolation='bicubic'): """Preprocesses the given image for evaluation. Args: image_bytes: `Tensor` representing an image binary of arbitrary size. use_bfloat16: `bool` for whether to use bfloat16. image_size: image size. + interpolation: image interpolation method Returns: A preprocessed image `Tensor`. """ - image = _decode_and_random_crop(image_bytes, image_size) + resize_method = tf.image.ResizeMethod.BICUBIC if interpolation == 'bicubic' else tf.image.ResizeMethod.BILINEAR + image = _decode_and_random_crop(image_bytes, image_size, resize_method) image = _flip(image) image = tf.reshape(image, [image_size, image_size, 3]) image = tf.image.convert_image_dtype( @@ -152,18 +153,20 @@ def preprocess_for_train(image_bytes, use_bfloat16, image_size=IMAGE_SIZE): return image -def preprocess_for_eval(image_bytes, use_bfloat16, image_size=IMAGE_SIZE): +def preprocess_for_eval(image_bytes, use_bfloat16, image_size=IMAGE_SIZE, interpolation='bicubic'): """Preprocesses the given image for evaluation. Args: image_bytes: `Tensor` representing an image binary of arbitrary size. use_bfloat16: `bool` for whether to use bfloat16. image_size: image size. + interpolation: image interpolation method Returns: A preprocessed image `Tensor`. """ - image = _decode_and_center_crop(image_bytes, image_size) + resize_method = tf.image.ResizeMethod.BICUBIC if interpolation == 'bicubic' else tf.image.ResizeMethod.BILINEAR + image = _decode_and_center_crop(image_bytes, image_size, resize_method) image = tf.reshape(image, [image_size, image_size, 3]) image = tf.image.convert_image_dtype( image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32) @@ -173,7 +176,8 @@ def preprocess_for_eval(image_bytes, use_bfloat16, image_size=IMAGE_SIZE): def preprocess_image(image_bytes, is_training=False, use_bfloat16=False, - image_size=IMAGE_SIZE): + image_size=IMAGE_SIZE, + interpolation='bicubic'): """Preprocesses the given image. Args: @@ -181,21 +185,23 @@ def preprocess_image(image_bytes, is_training: `bool` for whether the preprocessing is for training. use_bfloat16: `bool` for whether to use bfloat16. image_size: image size. + interpolation: image interpolation method Returns: A preprocessed image `Tensor` with value range of [0, 255]. """ if is_training: - return preprocess_for_train(image_bytes, use_bfloat16, image_size) + return preprocess_for_train(image_bytes, use_bfloat16, image_size, interpolation) else: - return preprocess_for_eval(image_bytes, use_bfloat16, image_size) + return preprocess_for_eval(image_bytes, use_bfloat16, image_size, interpolation) class TfPreprocessTransform: - def __init__(self, is_training=False, size=224): + def __init__(self, is_training=False, size=224, interpolation='bicubic'): self.is_training = is_training self.size = size[0] if isinstance(size, tuple) else size + self.interpolation = interpolation self._image_bytes = None self.process_image = self._build_tf_graph() self.sess = None @@ -206,7 +212,8 @@ class TfPreprocessTransform: shape=[], dtype=tf.string, ) - img = preprocess_image(self._image_bytes, self.is_training, False, self.size) + img = preprocess_image( + self._image_bytes, self.is_training, False, self.size, self.interpolation) return img def __call__(self, image_bytes): diff --git a/timm/models/conv2d_helpers.py b/timm/models/conv2d_helpers.py new file mode 100644 index 00000000..674eadca --- /dev/null +++ b/timm/models/conv2d_helpers.py @@ -0,0 +1,120 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + + +def _is_static_pad(kernel_size, stride=1, dilation=1, **_): + return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 + + +def _get_padding(kernel_size, stride=1, dilation=1, **_): + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + return padding + + +def _calc_same_pad(i, k, s, d): + return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) + + +def _split_channels(num_chan, num_groups): + split = [num_chan // num_groups for _ in range(num_groups)] + split[0] += num_chan - sum(split) + return split + + +class Conv2dSame(nn.Conv2d): + """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions + """ + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True): + super(Conv2dSame, self).__init__( + in_channels, out_channels, kernel_size, stride, 0, dilation, + groups, bias) + + def forward(self, x): + ih, iw = x.size()[-2:] + kh, kw = self.weight.size()[-2:] + pad_h = _calc_same_pad(ih, kh, self.stride[0], self.dilation[0]) + pad_w = _calc_same_pad(iw, kw, self.stride[1], self.dilation[1]) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [pad_w//2, pad_w - pad_w//2, pad_h//2, pad_h - pad_h//2]) + return F.conv2d(x, self.weight, self.bias, self.stride, + self.padding, self.dilation, self.groups) + + +def conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): + padding = kwargs.pop('padding', '') + kwargs.setdefault('bias', False) + if isinstance(padding, str): + # for any string padding, the padding will be calculated for you, one of three ways + padding = padding.lower() + if padding == 'same': + # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact + if _is_static_pad(kernel_size, **kwargs): + # static case, no extra overhead + padding = _get_padding(kernel_size, **kwargs) + return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) + else: + # dynamic padding + return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) + elif padding == 'valid': + # 'VALID' padding, same as padding=0 + return nn.Conv2d(in_chs, out_chs, kernel_size, padding=0, **kwargs) + else: + # Default to PyTorch style 'same'-ish symmetric padding + padding = _get_padding(kernel_size, **kwargs) + return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) + else: + # padding was specified as a number or pair + return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) + + +class MixedConv2d(nn.Module): + """ Mixed Grouped Convolution + Based on MDConv and GroupedConv in MixNet impl: + https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py + """ + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding='', dilated=False, depthwise=False, **kwargs): + super(MixedConv2d, self).__init__() + + kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] + num_groups = len(kernel_size) + in_splits = _split_channels(in_channels, num_groups) + out_splits = _split_channels(out_channels, num_groups) + for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): + d = 1 + # FIXME make compat with non-square kernel/dilations/strides + if stride == 1 and dilated: + d, k = (k - 1) // 2, 3 + conv_groups = out_ch if depthwise else 1 + # use add_module to keep key space clean + self.add_module( + str(idx), + conv2d_pad( + in_ch, out_ch, k, stride=stride, + padding=padding, dilation=d, groups=conv_groups, **kwargs) + ) + self.splits = in_splits + + def forward(self, x): + x_split = torch.split(x, self.splits, 1) + x_out = [c(x) for x, c in zip(x_split, self._modules.values())] + x = torch.cat(x_out, 1) + return x + + +# helper method +def select_conv2d(in_chs, out_chs, kernel_size, **kwargs): + assert 'groups' not in kwargs # only use 'depthwise' bool arg + if isinstance(kernel_size, list): + # We're going to use only lists for defining the MixedConv2d kernel groups, + # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. + return MixedConv2d(in_chs, out_chs, kernel_size, **kwargs) + else: + depthwise = kwargs.pop('depthwise', False) + groups = out_chs if depthwise else 1 + return conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs) + diff --git a/timm/models/conv2d_same.py b/timm/models/conv2d_same.py deleted file mode 100644 index 6d0b9e09..00000000 --- a/timm/models/conv2d_same.py +++ /dev/null @@ -1,39 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import math - - -class Conv2dSame(nn.Conv2d): - """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions - """ - def __init__(self, in_channels, out_channels, kernel_size, stride=1, - padding=0, dilation=1, groups=1, bias=True): - super(Conv2dSame, self).__init__( - in_channels, out_channels, kernel_size, stride, 0, dilation, - groups, bias) - - def forward(self, x): - ih, iw = x.size()[-2:] - kh, kw = self.weight.size()[-2:] - oh = math.ceil(ih / self.stride[0]) - ow = math.ceil(iw / self.stride[1]) - pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) - pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) - if pad_h > 0 or pad_w > 0: - x = F.pad(x, [pad_w//2, pad_w - pad_w//2, pad_h//2, pad_h - pad_h//2]) - return F.conv2d(x, self.weight, self.bias, self.stride, - self.padding, self.dilation, self.groups) - - -# helper method -def sconv2d(in_chs, out_chs, kernel_size, **kwargs): - padding = kwargs.pop('padding', 0) - if isinstance(padding, str): - if padding.lower() == 'same': - return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) - else: - # 'valid' - return nn.Conv2d(in_chs, out_chs, kernel_size, padding=0, **kwargs) - else: - return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) diff --git a/timm/models/gen_efficientnet.py b/timm/models/gen_efficientnet.py index 6e54c70e..d0e6ab43 100644 --- a/timm/models/gen_efficientnet.py +++ b/timm/models/gen_efficientnet.py @@ -1,13 +1,13 @@ """ Generic EfficientNets A generic class with building blocks to support a variety of models with efficient architectures: -* EfficientNet (B0-B4 in code right now, work in progress, still verifying) -* MNasNet B1, A1 (SE), Small -* MobileNet V1, V2, and V3 (work in progress) +* EfficientNet (B0-B5) +* MixNet (Small, Medium, and Large) +* MnasNet B1, A1 (SE), Small +* MobileNet V1, V2, and V3 * FBNet-C (TODO A & B) * ChamNet (TODO still guessing at architecture definition) * Single-Path NAS Pixel1 -* ShuffleNetV2 (TODO add IR shuffle block) * And likely more... TODO not all combinations and variations have been tested. Currently working on training hyper-params... @@ -27,7 +27,7 @@ import torch.nn.functional as F from .registry import register_model from .helpers import load_pretrained from .adaptive_avgmax_pool import SelectAdaptivePool2d -from .conv2d_same import sconv2d +from .conv2d_helpers import select_conv2d from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD @@ -37,7 +37,7 @@ __all__ = ['GenEfficientNet'] def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), - 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'crop_pct': 0.875, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'first_conv': 'conv_stem', 'classifier': 'classifier', **kwargs @@ -48,14 +48,12 @@ default_cfgs = { 'mnasnet_050': _cfg(url=''), 'mnasnet_075': _cfg(url=''), 'mnasnet_100': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth', - interpolation='bicubic'), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth'), 'mnasnet_140': _cfg(url=''), 'semnasnet_050': _cfg(url=''), 'semnasnet_075': _cfg(url=''), 'semnasnet_100': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth', - interpolation='bicubic'), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth'), 'semnasnet_140': _cfg(url=''), 'mnasnet_small': _cfg(url=''), 'mobilenetv1_100': _cfg(url=''), @@ -63,23 +61,23 @@ default_cfgs = { 'mobilenetv3_050': _cfg(url=''), 'mobilenetv3_075': _cfg(url=''), 'mobilenetv3_100': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth', - interpolation='bicubic'), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth'), 'chamnetv1_100': _cfg(url=''), 'chamnetv2_100': _cfg(url=''), 'fbnetc_100': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth'), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth', + interpolation='bilinear'), 'spnasnet_100': _cfg( - url='https://www.dropbox.com/s/iieopt18rytkgaa/spnasnet_100-048bc3f4.pth?dl=1'), + url='https://www.dropbox.com/s/iieopt18rytkgaa/spnasnet_100-048bc3f4.pth?dl=1', + interpolation='bilinear'), 'efficientnet_b0': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0-d6904d92.pth', - interpolation='bicubic'), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0-d6904d92.pth'), 'efficientnet_b1': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth', - input_size=(3, 240, 240), pool_size=(8, 8), interpolation='bicubic', crop_pct=0.882), + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), 'efficientnet_b2': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2-cf78dc4d.pth', - input_size=(3, 260, 260), pool_size=(9, 9), interpolation='bicubic', crop_pct=0.890), + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), 'efficientnet_b3': _cfg( url='', input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), 'efficientnet_b4': _cfg( @@ -88,22 +86,31 @@ default_cfgs = { url='', input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), 'tf_efficientnet_b0': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0-0af12548.pth', - input_size=(3, 224, 224), interpolation='bicubic'), + input_size=(3, 224, 224)), 'tf_efficientnet_b1': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1-5c1377c4.pth', - input_size=(3, 240, 240), pool_size=(8, 8), interpolation='bicubic', crop_pct=0.882), + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), 'tf_efficientnet_b2': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2-e393ef04.pth', - input_size=(3, 260, 260), pool_size=(9, 9), interpolation='bicubic', crop_pct=0.890), + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), 'tf_efficientnet_b3': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3-e3bd6955.pth', - input_size=(3, 300, 300), pool_size=(10, 10), interpolation='bicubic', crop_pct=0.904), + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), 'tf_efficientnet_b4': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4-74ee3bed.pth', - input_size=(3, 380, 380), pool_size=(12, 12), interpolation='bicubic', crop_pct=0.922), + input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), 'tf_efficientnet_b5': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5-c6949ce9.pth', - input_size=(3, 456, 456), pool_size=(15, 15), interpolation='bicubic', crop_pct=0.934) + input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), + 'mixnet_s': _cfg(url=''), + 'mixnet_m': _cfg(url=''), + 'mixnet_l': _cfg(url=''), + 'tf_mixnet_s': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth'), + 'tf_mixnet_m': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_m-0f4d8805.pth'), + 'tf_mixnet_l': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_l-6c92e0c8.pth'), } @@ -151,6 +158,13 @@ def _round_channels(channels, multiplier=1.0, divisor=8, channel_min=None): return new_channels +def _parse_ksize(ss): + if ss.isdigit(): + return int(ss) + else: + return [int(k) for k in ss.split('.')] + + def _decode_block_str(block_str, depth_multiplier=1.0): """ Decode block definition string @@ -168,7 +182,7 @@ def _decode_block_str(block_str, depth_multiplier=1.0): e - expansion ratio, c - output channels, se - squeeze/excitation ratio - a - activation fn ('re', 'r6', or 'hs') + n - activation fn ('re', 'r6', 'hs', or 'sw') Args: block_str: a string representation of block arguments. Returns: @@ -184,7 +198,9 @@ def _decode_block_str(block_str, depth_multiplier=1.0): noskip = False for op in ops: # string options being checked on individual basis, combine if they grow - if op.startswith('a'): + if op == 'noskip': + noskip = True + elif op.startswith('n'): # activation fn key = op[0] v = op[1:] @@ -194,11 +210,11 @@ def _decode_block_str(block_str, depth_multiplier=1.0): value = F.relu6 elif v == 'hs': value = hard_swish + elif v == 'sw': + value = swish else: continue options[key] = value - elif op == 'noskip': - noskip = True else: # all numeric options splits = re.split(r'(\d.*)', op) @@ -207,14 +223,18 @@ def _decode_block_str(block_str, depth_multiplier=1.0): options[key] = value # if act_fn is None, the model default (passed to model init) will be used - act_fn = options['a'] if 'a' in options else None + act_fn = options['n'] if 'n' in options else None + exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1 + pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1 num_repeat = int(options['r']) # each type of block has different valid arguments, fill accordingly if block_type == 'ir': block_args = dict( block_type=block_type, - kernel_size=int(options['k']), + dw_kernel_size=_parse_ksize(options['k']), + exp_kernel_size=exp_kernel_size, + pw_kernel_size=pw_kernel_size, out_chs=int(options['c']), exp_ratio=float(options['e']), se_ratio=float(options['se']) if 'se' in options else None, @@ -222,20 +242,17 @@ def _decode_block_str(block_str, depth_multiplier=1.0): act_fn=act_fn, noskip=noskip, ) - if 'g' in options: - block_args['pw_group'] = options['g'] - if options['g'] > 1: - block_args['shuffle_type'] = 'mid' elif block_type == 'ds' or block_type == 'dsa': block_args = dict( block_type=block_type, - kernel_size=int(options['k']), + dw_kernel_size=_parse_ksize(options['k']), + pw_kernel_size=pw_kernel_size, out_chs=int(options['c']), se_ratio=float(options['se']) if 'se' in options else None, stride=int(options['s']), act_fn=act_fn, - noskip=block_type == 'dsa' or noskip, pw_act=block_type == 'dsa', + noskip=block_type == 'dsa' or noskip, ) elif block_type == 'cn': block_args = dict( @@ -254,15 +271,6 @@ def _decode_block_str(block_str, depth_multiplier=1.0): return [deepcopy(block_args) for _ in range(num_repeat)] -def _get_padding(kernel_size, stride, dilation=1): - padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 - return padding - - -def _padding_arg(default, padding_same=False): - return 'SAME' if padding_same else default - - def _decode_arch_args(string_list): block_args = [] for block_str in string_list: @@ -316,20 +324,18 @@ class _BlockBuilder: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py """ - def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None, - drop_connect_rate=0., act_fn=None, se_gate_fn=sigmoid, se_reduce_mid=False, - bn_args=_BN_ARGS_PT, padding_same=False, - verbose=False): + pad_type='', act_fn=None, se_gate_fn=sigmoid, se_reduce_mid=False, + bn_args=_BN_ARGS_PT, drop_connect_rate=0., verbose=False): self.channel_multiplier = channel_multiplier self.channel_divisor = channel_divisor self.channel_min = channel_min - self.drop_connect_rate = drop_connect_rate + self.pad_type = pad_type self.act_fn = act_fn self.se_gate_fn = se_gate_fn self.se_reduce_mid = se_reduce_mid self.bn_args = bn_args - self.padding_same = padding_same + self.drop_connect_rate = drop_connect_rate self.verbose = verbose # updated during build @@ -345,7 +351,7 @@ class _BlockBuilder: ba['in_chs'] = self.in_chs ba['out_chs'] = self._round_channels(ba['out_chs']) ba['bn_args'] = self.bn_args - ba['padding_same'] = self.padding_same + ba['pad_type'] = self.pad_type # block act fn overrides the model default ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn assert ba['act_fn'] is not None @@ -493,16 +499,11 @@ class SqueezeExcite(nn.Module): class ConvBnAct(nn.Module): def __init__(self, in_chs, out_chs, kernel_size, - stride=1, act_fn=F.relu, - bn_args=_BN_ARGS_PT, padding_same=False): + stride=1, pad_type='', act_fn=F.relu, bn_args=_BN_ARGS_PT): super(ConvBnAct, self).__init__() assert stride in [1, 2] self.act_fn = act_fn - padding = _padding_arg(_get_padding(kernel_size, stride), padding_same) - - self.conv = sconv2d( - in_chs, out_chs, kernel_size, - stride=stride, padding=padding, bias=False) + self.conv = select_conv2d(in_chs, out_chs, kernel_size, stride=stride, padding=pad_type) self.bn1 = nn.BatchNorm2d(out_chs, **bn_args) def forward(self, x): @@ -517,10 +518,11 @@ class DepthwiseSeparableConv(nn.Module): Used for DS convs in MobileNet-V1 and in the place of IR blocks with an expansion factor of 1.0. This is an alternative to having a IR with optional first pw conv. """ - def __init__(self, in_chs, out_chs, kernel_size, - stride=1, act_fn=F.relu, noskip=False, pw_act=False, + def __init__(self, in_chs, out_chs, dw_kernel_size=3, + stride=1, pad_type='', act_fn=F.relu, noskip=False, + pw_kernel_size=1, pw_act=False, se_ratio=0., se_gate_fn=sigmoid, - bn_args=_BN_ARGS_PT, padding_same=False, drop_connect_rate=0.): + bn_args=_BN_ARGS_PT, drop_connect_rate=0.): super(DepthwiseSeparableConv, self).__init__() assert stride in [1, 2] self.has_se = se_ratio is not None and se_ratio > 0. @@ -528,12 +530,9 @@ class DepthwiseSeparableConv(nn.Module): self.has_pw_act = pw_act # activation after point-wise conv self.act_fn = act_fn self.drop_connect_rate = drop_connect_rate - dw_padding = _padding_arg(kernel_size // 2, padding_same) - pw_padding = _padding_arg(0, padding_same) - self.conv_dw = sconv2d( - in_chs, in_chs, kernel_size, - stride=stride, padding=dw_padding, groups=in_chs, bias=False) + self.conv_dw = select_conv2d( + in_chs, in_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True) self.bn1 = nn.BatchNorm2d(in_chs, **bn_args) # Squeeze-and-excitation @@ -541,7 +540,7 @@ class DepthwiseSeparableConv(nn.Module): self.se = SqueezeExcite( in_chs, reduce_chs=max(1, int(in_chs * se_ratio)), act_fn=act_fn, gate_fn=se_gate_fn) - self.conv_pw = sconv2d(in_chs, out_chs, 1, padding=pw_padding, bias=False) + self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type) self.bn2 = nn.BatchNorm2d(out_chs, **bn_args) def forward(self, x): @@ -569,31 +568,29 @@ class DepthwiseSeparableConv(nn.Module): class InvertedResidual(nn.Module): """ Inverted residual block w/ optional SE""" - def __init__(self, in_chs, out_chs, kernel_size, - stride=1, act_fn=F.relu, exp_ratio=1.0, noskip=False, + def __init__(self, in_chs, out_chs, dw_kernel_size=3, + stride=1, pad_type='', act_fn=F.relu, noskip=False, + exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, se_ratio=0., se_reduce_mid=False, se_gate_fn=sigmoid, - shuffle_type=None, pw_group=1, - bn_args=_BN_ARGS_PT, padding_same=False, drop_connect_rate=0.): + shuffle_type=None, bn_args=_BN_ARGS_PT, drop_connect_rate=0.): super(InvertedResidual, self).__init__() mid_chs = int(in_chs * exp_ratio) self.has_se = se_ratio is not None and se_ratio > 0. self.has_residual = (in_chs == out_chs and stride == 1) and not noskip self.act_fn = act_fn self.drop_connect_rate = drop_connect_rate - dw_padding = _padding_arg(kernel_size // 2, padding_same) - pw_padding = _padding_arg(0, padding_same) # Point-wise expansion - self.conv_pw = sconv2d(in_chs, mid_chs, 1, padding=pw_padding, groups=pw_group, bias=False) + self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type) self.bn1 = nn.BatchNorm2d(mid_chs, **bn_args) self.shuffle_type = shuffle_type - if shuffle_type is not None: - self.shuffle = ChannelShuffle(pw_group) + if shuffle_type is not None and isinstance(exp_kernel_size, list): + self.shuffle = ChannelShuffle(len(exp_kernel_size)) # Depth-wise convolution - self.conv_dw = sconv2d( - mid_chs, mid_chs, kernel_size, padding=dw_padding, stride=stride, groups=mid_chs, bias=False) + self.conv_dw = select_conv2d( + mid_chs, mid_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True) self.bn2 = nn.BatchNorm2d(mid_chs, **bn_args) # Squeeze-and-excitation @@ -603,7 +600,7 @@ class InvertedResidual(nn.Module): mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)), act_fn=act_fn, gate_fn=se_gate_fn) # Point-wise linear projection - self.conv_pwl = sconv2d(mid_chs, out_chs, 1, padding=pw_padding, groups=pw_group, bias=False) + self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type) self.bn3 = nn.BatchNorm2d(out_chs, **bn_args) def forward(self, x): @@ -649,18 +646,19 @@ class GenEfficientNet(nn.Module): * MobileNet-V1 * MobileNet-V2 * MobileNet-V3 - * MNASNet A1, B1, and small + * MnasNet A1, B1, and small * FBNet A, B, and C * ChamNet (arch details are murky) * Single-Path NAS Pixel1 - * EfficientNetB0-B4 (rest easy to add) + * EfficientNet B0-B5 + * MixNet S, M, L """ def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280, channel_multiplier=1.0, channel_divisor=8, channel_min=None, - drop_rate=0., drop_connect_rate=0., act_fn=F.relu, + pad_type='', act_fn=F.relu, drop_rate=0., drop_connect_rate=0., se_gate_fn=sigmoid, se_reduce_mid=False, bn_args=_BN_ARGS_PT, - global_pool='avg', head_conv='default', weight_init='goog', padding_same=False): + global_pool='avg', head_conv='default', weight_init='goog'): super(GenEfficientNet, self).__init__() self.num_classes = num_classes self.drop_rate = drop_rate @@ -668,16 +666,14 @@ class GenEfficientNet(nn.Module): self.num_features = num_features stem_size = _round_channels(stem_size, channel_multiplier, channel_divisor, channel_min) - self.conv_stem = sconv2d( - in_chans, stem_size, 3, - padding=_padding_arg(1, padding_same), stride=2, bias=False) + self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) self.bn1 = nn.BatchNorm2d(stem_size, **bn_args) in_chs = stem_size builder = _BlockBuilder( channel_multiplier, channel_divisor, channel_min, - drop_connect_rate, act_fn, se_gate_fn, se_reduce_mid, - bn_args, padding_same, verbose=_DEBUG) + pad_type, act_fn, se_gate_fn, se_reduce_mid, + bn_args, drop_connect_rate, verbose=_DEBUG) self.blocks = nn.Sequential(*builder(in_chs, block_args)) in_chs = builder.in_chs @@ -687,9 +683,7 @@ class GenEfficientNet(nn.Module): assert in_chs == self.num_features else: self.efficient_head = head_conv == 'efficient' - self.conv_head = sconv2d( - in_chs, self.num_features, 1, - padding=_padding_arg(0, padding_same), bias=False) + self.conv_head = select_conv2d(in_chs, self.num_features, 1, padding=pad_type) self.bn2 = None if self.efficient_head else nn.BatchNorm2d(self.num_features, **bn_args) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) @@ -919,11 +913,11 @@ def _gen_mobilenet_v3(channel_multiplier, num_classes=1000, **kwargs): """ arch_def = [ # stage 0, 112x112 in - ['ds_r1_k3_s1_e1_c16_are_noskip'], # relu + ['ds_r1_k3_s1_e1_c16_nre_noskip'], # relu # stage 1, 112x112 in - ['ir_r1_k3_s2_e4_c24_are', 'ir_r1_k3_s1_e3_c24_are'], # relu + ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu # stage 2, 56x56 in - ['ir_r3_k5_s2_e3_c40_se0.25_are'], # relu + ['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu # stage 3, 28x28 in ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish # stage 4, 14x14in @@ -1129,6 +1123,78 @@ def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes= return model +def _gen_mixnet_s(channel_multiplier=1.0, num_classes=1000, **kwargs): + """Creates a MixNet Small model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet + Paper: https://arxiv.org/abs/1907.09595 + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16'], # relu + # stage 1, 112x112 in + ['ir_r1_k3_a1.1_p1.1_s2_e6_c24', 'ir_r1_k3_a1.1_p1.1_s1_e3_c24'], # relu + # stage 2, 56x56 in + ['ir_r1_k3.5.7_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish + # stage 3, 28x28 in + ['ir_r1_k3.5.7_p1.1_s2_e6_c80_se0.25_nsw', 'ir_r2_k3.5_p1.1_s1_e6_c80_se0.25_nsw'], # swish + # stage 4, 14x14in + ['ir_r1_k3.5.7_a1.1_p1.1_s1_e6_c120_se0.5_nsw', 'ir_r2_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish + # stage 5, 14x14in + ['ir_r1_k3.5.7.9.11_s2_e6_c200_se0.5_nsw', 'ir_r2_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish + # 7x7 + ] + model = GenEfficientNet( + _decode_arch_def(arch_def), + num_classes=num_classes, + stem_size=16, + num_features=1536, + channel_multiplier=channel_multiplier, + channel_divisor=8, + channel_min=None, + bn_args=_resolve_bn_args(kwargs), + act_fn=F.relu, + **kwargs + ) + return model + + +def _gen_mixnet_m(channel_multiplier=1.0, num_classes=1000, **kwargs): + """Creates a MixNet Medium-Large model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet + Paper: https://arxiv.org/abs/1907.09595 + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c24'], # relu + # stage 1, 112x112 in + ['ir_r1_k3.5.7_a1.1_p1.1_s2_e6_c32', 'ir_r1_k3_a1.1_p1.1_s1_e3_c32'], # relu + # stage 2, 56x56 in + ['ir_r1_k3.5.7.9_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish + # stage 3, 28x28 in + ['ir_r1_k3.5.7_s2_e6_c80_se0.25_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e6_c80_se0.25_nsw'], # swish + # stage 4, 14x14in + ['ir_r1_k3_s1_e6_c120_se0.5_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish + # stage 5, 14x14in + ['ir_r1_k3.5.7.9_s2_e6_c200_se0.5_nsw', 'ir_r3_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish + # 7x7 + ] + model = GenEfficientNet( + _decode_arch_def(arch_def), + num_classes=num_classes, + stem_size=24, + num_features=1536, + channel_multiplier=channel_multiplier, + channel_divisor=8, + channel_min=None, + bn_args=_resolve_bn_args(kwargs), + act_fn=F.relu, + **kwargs + ) + return model + + @register_model def mnasnet_050(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ MNASNet B1, depth multiplier of 0.5. """ @@ -1440,7 +1506,7 @@ def tf_efficientnet_b0(pretrained=False, num_classes=1000, in_chans=3, **kwargs) """ EfficientNet-B0. Tensorflow compatible variant """ default_cfg = default_cfgs['tf_efficientnet_b0'] kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT - kwargs['padding_same'] = True + kwargs['pad_type'] = 'same' model = _gen_efficientnet( channel_multiplier=1.0, depth_multiplier=1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) @@ -1455,7 +1521,7 @@ def tf_efficientnet_b1(pretrained=False, num_classes=1000, in_chans=3, **kwargs) """ EfficientNet-B1. Tensorflow compatible variant """ default_cfg = default_cfgs['tf_efficientnet_b1'] kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT - kwargs['padding_same'] = True + kwargs['pad_type'] = 'same' model = _gen_efficientnet( channel_multiplier=1.0, depth_multiplier=1.1, num_classes=num_classes, in_chans=in_chans, **kwargs) @@ -1470,7 +1536,7 @@ def tf_efficientnet_b2(pretrained=False, num_classes=1000, in_chans=3, **kwargs) """ EfficientNet-B2. Tensorflow compatible variant """ default_cfg = default_cfgs['tf_efficientnet_b2'] kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT - kwargs['padding_same'] = True + kwargs['pad_type'] = 'same' model = _gen_efficientnet( channel_multiplier=1.1, depth_multiplier=1.2, num_classes=num_classes, in_chans=in_chans, **kwargs) @@ -1485,7 +1551,7 @@ def tf_efficientnet_b3(pretrained=False, num_classes=1000, in_chans=3, **kwargs) """ EfficientNet-B3. Tensorflow compatible variant """ default_cfg = default_cfgs['tf_efficientnet_b3'] kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT - kwargs['padding_same'] = True + kwargs['pad_type'] = 'same' model = _gen_efficientnet( channel_multiplier=1.2, depth_multiplier=1.4, num_classes=num_classes, in_chans=in_chans, **kwargs) @@ -1500,7 +1566,7 @@ def tf_efficientnet_b4(pretrained=False, num_classes=1000, in_chans=3, **kwargs) """ EfficientNet-B4. Tensorflow compatible variant """ default_cfg = default_cfgs['tf_efficientnet_b4'] kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT - kwargs['padding_same'] = True + kwargs['pad_type'] = 'same' model = _gen_efficientnet( channel_multiplier=1.4, depth_multiplier=1.8, num_classes=num_classes, in_chans=in_chans, **kwargs) @@ -1515,7 +1581,7 @@ def tf_efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs) """ EfficientNet-B5. Tensorflow compatible variant """ default_cfg = default_cfgs['tf_efficientnet_b5'] kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT - kwargs['padding_same'] = True + kwargs['pad_type'] = 'same' model = _gen_efficientnet( channel_multiplier=1.6, depth_multiplier=2.2, num_classes=num_classes, in_chans=in_chans, **kwargs) @@ -1525,5 +1591,89 @@ def tf_efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs) return model +@register_model +def mixnet_s(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Creates a MixNet Small model. + """ + default_cfg = default_cfgs['mixnet_m'] + model = _gen_mixnet_s( + channel_multiplier=1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + #if pretrained: + # load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def mixnet_m(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Creates a MixNet Medium model. + """ + default_cfg = default_cfgs['mixnet_m'] + model = _gen_mixnet_m( + channel_multiplier=1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + #if pretrained: + # load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def mixnet_l(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Creates a MixNet Large model. + """ + default_cfg = default_cfgs['mixnet_l'] + model = _gen_mixnet_m( + channel_multiplier=1.3, num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def tf_mixnet_s(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Creates a MixNet Small model. Tensorflow compatible variant + """ + default_cfg = default_cfgs['tf_mixnet_s'] + kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mixnet_s( + channel_multiplier=1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def tf_mixnet_m(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Creates a MixNet Medium model. Tensorflow compatible variant + """ + default_cfg = default_cfgs['tf_mixnet_m'] + kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mixnet_m( + channel_multiplier=1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def tf_mixnet_l(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Creates a MixNet Large model. Tensorflow compatible variant + """ + default_cfg = default_cfgs['tf_mixnet_l'] + kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mixnet_m( + channel_multiplier=1.3, num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + def gen_efficientnet_model_names(): return set(_models)