Merge pull request #636 from rwightman/effnetv2_official

EfficientNetV2 official impl w/ weights ported from TF.
pull/647/head
Ross Wightman 3 years ago committed by GitHub
commit 758802e1f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -23,6 +23,15 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
## What's New
### May 14, 2021
* Add EfficientNet-V2 official model defs w/ ported weights from official [Tensorflow/Keras](https://github.com/google/automl/tree/master/efficientnetv2) impl.
* 1k trained variants: `tf_efficientnetv2_s/m/l`
* 21k trained variants: `tf_efficientnetv2_s/m/l_21k`
* 21k pretrained -> 1k fine-tuned: `tf_efficientnetv2_s/m/l_21ft1k`
* v2 models w/ v1 scaling: `tf_efficientnet_v2_b0` through `b3`
* Rename my prev V2 guess `efficientnet_v2s` -> `efficientnetv2_rw_s`
* Some blank `efficientnetv2_*` models in-place for future native PyTorch training
### May 5, 2021
* Add MLP-Mixer models and port pretrained weights from [Google JAX impl](https://github.com/google-research/vision_transformer/tree/linen)
* Add CaiT models and pretrained weights from [FB](https://github.com/facebookresearch/deit)

@ -1,7 +1,10 @@
""" PyTorch EfficientNet Family
""" The EfficientNet Family in PyTorch
An implementation of EfficienNet that covers variety of related models with efficient architectures:
* EfficientNet-V2
- `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298
* EfficientNet (B0-B8, L2 + Tensorflow pretrained AutoAug/RandAug/AdvProp/NoisyStudent weight ports)
- EfficientNet: Rethinking Model Scaling for CNNs - https://arxiv.org/abs/1905.11946
- CondConv: Conditionally Parameterized Convolutions for Efficient Inference - https://arxiv.org/abs/1904.04971
@ -22,23 +25,30 @@ An implementation of EfficienNet that covers variety of related models with effi
* And likely more...
Hacked together by / Copyright 2020 Ross Wightman
The majority of the above models (EfficientNet*, MixNet, MnasNet) and original weights were made available
by Mingxing Tan, Quoc Le, and other members of their Google Brain team. Thanks for consistently releasing
the models and weights open source!
Hacked together by / Copyright 2021 Ross Wightman
"""
from functools import partial
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
from .efficientnet_blocks import SqueezeExcite
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights,\
round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from .features import FeatureInfo, FeatureHooks
from .helpers import build_model_with_cfg, default_cfg_for_features
from .layers import create_conv2d, create_classifier
from .registry import register_model
__all__ = ['EfficientNet']
__all__ = ['EfficientNet', 'EfficientNetFeatures']
def _cfg(url='', **kwargs):
@ -149,9 +159,20 @@ default_cfgs = {
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb3_pruned_5abcc29f.pth',
input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'efficientnet_v2s': _cfg(
'efficientnetv2_rw_s': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_v2s_ra2_288-a6477665.pth',
input_size=(3, 288, 288), test_input_size=(3, 384, 384), pool_size=(9, 9), crop_pct=1.0), # FIXME WIP
input_size=(3, 288, 288), test_input_size=(3, 384, 384), pool_size=(9, 9), crop_pct=1.0),
'efficientnetv2_s': _cfg(
url='',
input_size=(3, 288, 288), test_input_size=(3, 384, 384), pool_size=(9, 9), crop_pct=1.0),
'efficientnetv2_m': _cfg(
url='',
input_size=(3, 320, 320), test_input_size=(3, 416, 416), pool_size=(10, 10), crop_pct=1.0),
'efficientnetv2_l': _cfg(
url='',
input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0),
'tf_efficientnet_b0': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
@ -298,6 +319,58 @@ default_cfgs = {
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.920, interpolation='bilinear'),
'tf_efficientnetv2_s': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s-eb54923e.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0),
'tf_efficientnetv2_m': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m-cc09e0cd.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0),
'tf_efficientnetv2_l': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l-d664b728.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0),
'tf_efficientnetv2_s_21ft1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0),
'tf_efficientnetv2_m_21ft1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m_21ft1k-bf41664a.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0),
'tf_efficientnetv2_l_21ft1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l_21ft1k-60127a9d.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0),
'tf_efficientnetv2_s_21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21k-6337ad01.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843,
input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0),
'tf_efficientnetv2_m_21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m_21k-361418a2.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843,
input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0),
'tf_efficientnetv2_l_21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l_21k-91a19ec9.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843,
input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0),
'tf_efficientnetv2_b0': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b0-c7cc451f.pth',
input_size=(3, 192, 192), test_input_size=(3, 224, 224), pool_size=(6, 6)),
'tf_efficientnetv2_b1': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b1-be6e41b0.pth',
input_size=(3, 192, 192), test_input_size=(3, 240, 240), pool_size=(6, 6), crop_pct=0.882),
'tf_efficientnetv2_b2': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b2-847de54e.pth',
input_size=(3, 208, 208), test_input_size=(3, 260, 260), pool_size=(7, 7), crop_pct=0.890),
'tf_efficientnetv2_b3': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b3-57773f13.pth',
input_size=(3, 240, 240), test_input_size=(3, 300, 300), pool_size=(8, 8), crop_pct=0.904),
'mixnet_s': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth'),
'mixnet_m': _cfg(
@ -316,13 +389,12 @@ default_cfgs = {
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_l-6c92e0c8.pth'),
}
_DEBUG = False
class EfficientNet(nn.Module):
""" (Generic) EfficientNet
A flexible and performant PyTorch implementation of efficient network architectures, including:
* EfficientNet-V2 Small, Medium, Large & B0-B3
* EfficientNet B0-B8, L2
* EfficientNet-EdgeTPU
* EfficientNet-CondConv
@ -333,35 +405,35 @@ class EfficientNet(nn.Module):
"""
def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32,
channel_multiplier=1.0, channel_divisor=8, channel_min=None,
output_stride=32, pad_type='', fix_stem=False, act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'):
def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32, fix_stem=False,
output_stride=32, pad_type='', round_chs_fn=round_channels, act_layer=None, norm_layer=None,
se_layer=None, drop_rate=0., drop_path_rate=0., global_pool='avg'):
super(EfficientNet, self).__init__()
norm_kwargs = norm_kwargs or {}
act_layer = act_layer or nn.ReLU
norm_layer = norm_layer or nn.BatchNorm2d
se_layer = se_layer or SqueezeExcite
self.num_classes = num_classes
self.num_features = num_features
self.drop_rate = drop_rate
# Stem
if not fix_stem:
stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
stem_size = round_chs_fn(stem_size)
self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = norm_layer(stem_size, **norm_kwargs)
self.bn1 = norm_layer(stem_size)
self.act1 = act_layer(inplace=True)
# Middle stages (IR/ER/DS Blocks)
builder = EfficientNetBuilder(
channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs,
norm_layer, norm_kwargs, drop_path_rate, verbose=_DEBUG)
output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn,
act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate)
self.blocks = nn.Sequential(*builder(stem_size, block_args))
self.feature_info = builder.features
head_chs = builder.in_chs
# Head + Pooling
self.conv_head = create_conv2d(head_chs, self.num_features, 1, padding=pad_type)
self.bn2 = norm_layer(self.num_features, **norm_kwargs)
self.bn2 = norm_layer(self.num_features)
self.act2 = act_layer(inplace=True)
self.global_pool, self.classifier = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
@ -408,25 +480,27 @@ class EfficientNetFeatures(nn.Module):
and object detection models.
"""
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck',
in_chans=3, stem_size=32, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
output_stride=32, pad_type='', fix_stem=False, act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3,
stem_size=32, fix_stem=False, output_stride=32, pad_type='', round_chs_fn=round_channels,
act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.):
super(EfficientNetFeatures, self).__init__()
norm_kwargs = norm_kwargs or {}
act_layer = act_layer or nn.ReLU
norm_layer = norm_layer or nn.BatchNorm2d
se_layer = se_layer or SqueezeExcite
self.drop_rate = drop_rate
# Stem
if not fix_stem:
stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
stem_size = round_chs_fn(stem_size)
self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = norm_layer(stem_size, **norm_kwargs)
self.bn1 = norm_layer(stem_size)
self.act1 = act_layer(inplace=True)
# Middle stages (IR/ER/DS Blocks)
builder = EfficientNetBuilder(
channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs,
norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn,
act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate,
feature_location=feature_location)
self.blocks = nn.Sequential(*builder(stem_size, block_args))
self.feature_info = FeatureInfo(builder.features, out_indices)
self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices}
@ -505,8 +579,8 @@ def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs)
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
stem_size=32,
channel_multiplier=channel_multiplier,
norm_kwargs=resolve_bn_args(kwargs),
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
**kwargs
)
model = _create_effnet(variant, pretrained, **model_kwargs)
@ -541,8 +615,8 @@ def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs)
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
stem_size=32,
channel_multiplier=channel_multiplier,
norm_kwargs=resolve_bn_args(kwargs),
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
**kwargs
)
model = _create_effnet(variant, pretrained, **model_kwargs)
@ -570,8 +644,8 @@ def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwar
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
stem_size=8,
channel_multiplier=channel_multiplier,
norm_kwargs=resolve_bn_args(kwargs),
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
**kwargs
)
model = _create_effnet(variant, pretrained, **model_kwargs)
@ -593,13 +667,14 @@ def _gen_mobilenet_v2(
['ir_r3_k3_s2_e6_c160'],
['ir_r1_k3_s1_e6_c320'],
]
round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head),
num_features=1280 if fix_stem_head else round_channels(1280, channel_multiplier, 8, None),
num_features=1280 if fix_stem_head else round_chs_fn(1280),
stem_size=32,
fix_stem=fix_stem_head,
channel_multiplier=channel_multiplier,
norm_kwargs=resolve_bn_args(kwargs),
round_chs_fn=round_chs_fn,
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
act_layer=resolve_act_layer(kwargs, 'relu6'),
**kwargs
)
@ -629,8 +704,8 @@ def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
block_args=decode_arch_def(arch_def),
stem_size=16,
num_features=1984, # paper suggests this, but is not 100% clear
channel_multiplier=channel_multiplier,
norm_kwargs=resolve_bn_args(kwargs),
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
**kwargs
)
model = _create_effnet(variant, pretrained, **model_kwargs)
@ -664,8 +739,8 @@ def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
stem_size=32,
channel_multiplier=channel_multiplier,
norm_kwargs=resolve_bn_args(kwargs),
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
**kwargs
)
model = _create_effnet(variant, pretrained, **model_kwargs)
@ -705,13 +780,14 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
['ir_r4_k5_s2_e6_c192_se0.25'],
['ir_r1_k3_s1_e6_c320_se0.25'],
]
round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier),
num_features=round_channels(1280, channel_multiplier, 8, None),
num_features=round_chs_fn(1280),
stem_size=32,
channel_multiplier=channel_multiplier,
round_chs_fn=round_chs_fn,
act_layer=resolve_act_layer(kwargs, 'swish'),
norm_kwargs=resolve_bn_args(kwargs),
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
**kwargs,
)
model = _create_effnet(variant, pretrained, **model_kwargs)
@ -734,12 +810,13 @@ def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0
['ir_r4_k5_s1_e8_c144'],
['ir_r2_k5_s2_e8_c192'],
]
round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier),
num_features=round_channels(1280, channel_multiplier, 8, None),
num_features=round_chs_fn(1280),
stem_size=32,
channel_multiplier=channel_multiplier,
norm_kwargs=resolve_bn_args(kwargs),
round_chs_fn=round_chs_fn,
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
act_layer=resolve_act_layer(kwargs, 'relu'),
**kwargs,
)
@ -764,12 +841,13 @@ def _gen_efficientnet_condconv(
]
# NOTE unlike official impl, this one uses `cc<x>` option where x is the base number of experts for each stage and
# the expert_multiplier increases that on a per-model basis as with depth/channel multipliers
round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier, experts_multiplier=experts_multiplier),
num_features=round_channels(1280, channel_multiplier, 8, None),
num_features=round_chs_fn(1280),
stem_size=32,
channel_multiplier=channel_multiplier,
norm_kwargs=resolve_bn_args(kwargs),
round_chs_fn=round_chs_fn,
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
act_layer=resolve_act_layer(kwargs, 'swish'),
**kwargs,
)
@ -809,45 +887,137 @@ def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0
num_features=1280,
stem_size=32,
fix_stem=True,
channel_multiplier=channel_multiplier,
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
act_layer=resolve_act_layer(kwargs, 'relu6'),
norm_kwargs=resolve_bn_args(kwargs),
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
**kwargs,
)
model = _create_effnet(variant, pretrained, **model_kwargs)
return model
def _gen_efficientnet_v2s(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
""" Creates an EfficientNet-V2s model
NOTE: this is a preliminary definition based on paper, awaiting official code release for details
and weights
def _gen_efficientnetv2_base(
variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
""" Creates an EfficientNet-V2 base model
Ref impl:
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298
"""
arch_def = [
['cn_r1_k3_s1_e1_c16_skip'],
['er_r2_k3_s2_e4_c32'],
['er_r2_k3_s2_e4_c48'],
['ir_r3_k3_s2_e4_c96_se0.25'],
['ir_r5_k3_s1_e6_c112_se0.25'],
['ir_r8_k3_s2_e6_c192_se0.25'],
]
round_chs_fn = partial(round_channels, multiplier=channel_multiplier, round_limit=0.)
model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier),
num_features=round_chs_fn(1280),
stem_size=32,
round_chs_fn=round_chs_fn,
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
act_layer=resolve_act_layer(kwargs, 'silu'),
**kwargs,
)
model = _create_effnet(variant, pretrained, **model_kwargs)
return model
def _gen_efficientnetv2_s(
variant, channel_multiplier=1.0, depth_multiplier=1.0, rw=False, pretrained=False, **kwargs):
""" Creates an EfficientNet-V2 Small model
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298
NOTE: `rw` flag sets up 'small' variant to behave like my initial v2 small model,
before ref the impl was released.
"""
arch_def = [
# FIXME it's not clear if the FusedMBConv layers have SE enabled for the Small variant,
# Table 4 suggests no. 23.94M params w/o, 23.98 with which is closer to 24M.
# ['er_r2_k3_s1_e1_c24_se0.25'],
# ['er_r4_k3_s2_e4_c48_se0.25'],
# ['er_r4_k3_s2_e4_c64_se0.25'],
['er_r2_k3_s1_e1_c24'],
['cn_r2_k3_s1_e1_c24_skip'],
['er_r4_k3_s2_e4_c48'],
['er_r4_k3_s2_e4_c64'],
['ir_r6_k3_s2_e4_c128_se0.25'],
['ir_r9_k3_s1_e6_c160_se0.25'],
['ir_r15_k3_s2_e6_c272_se0.25'],
['ir_r15_k3_s2_e6_c256_se0.25'],
]
num_features = 1280
if rw:
# my original variant, based on paper figure differs from the official release
arch_def[0] = ['er_r2_k3_s1_e1_c24']
arch_def[-1] = ['ir_r15_k3_s2_e6_c272_se0.25']
num_features = 1792
round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier),
num_features=round_channels(1792, channel_multiplier, 8, None),
num_features=round_chs_fn(num_features),
stem_size=24,
channel_multiplier=channel_multiplier,
norm_kwargs=resolve_bn_args(kwargs),
act_layer=resolve_act_layer(kwargs, 'silu'), # FIXME this is an assumption, paper does not mention
round_chs_fn=round_chs_fn,
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
act_layer=resolve_act_layer(kwargs, 'silu'),
**kwargs,
)
model = _create_effnet(variant, pretrained, **model_kwargs)
return model
def _gen_efficientnetv2_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
""" Creates an EfficientNet-V2 Medium model
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298
"""
arch_def = [
['cn_r3_k3_s1_e1_c24_skip'],
['er_r5_k3_s2_e4_c48'],
['er_r5_k3_s2_e4_c80'],
['ir_r7_k3_s2_e4_c160_se0.25'],
['ir_r14_k3_s1_e6_c176_se0.25'],
['ir_r18_k3_s2_e6_c304_se0.25'],
['ir_r5_k3_s1_e6_c512_se0.25'],
]
model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier),
num_features=1280,
stem_size=24,
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
act_layer=resolve_act_layer(kwargs, 'silu'),
**kwargs,
)
model = _create_effnet(variant, pretrained, **model_kwargs)
return model
def _gen_efficientnetv2_l(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
""" Creates an EfficientNet-V2 Large model
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298
"""
arch_def = [
['cn_r4_k3_s1_e1_c32_skip'],
['er_r7_k3_s2_e4_c64'],
['er_r7_k3_s2_e4_c96'],
['ir_r10_k3_s2_e4_c192_se0.25'],
['ir_r19_k3_s1_e6_c224_se0.25'],
['ir_r25_k3_s2_e6_c384_se0.25'],
['ir_r7_k3_s1_e6_c640_se0.25'],
]
model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier),
num_features=1280,
stem_size=32,
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
act_layer=resolve_act_layer(kwargs, 'silu'),
**kwargs,
)
model = _create_effnet(variant, pretrained, **model_kwargs)
@ -879,8 +1049,8 @@ def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
block_args=decode_arch_def(arch_def),
num_features=1536,
stem_size=16,
channel_multiplier=channel_multiplier,
norm_kwargs=resolve_bn_args(kwargs),
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
**kwargs
)
model = _create_effnet(variant, pretrained, **model_kwargs)
@ -912,8 +1082,8 @@ def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrai
block_args=decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'),
num_features=1536,
stem_size=24,
channel_multiplier=channel_multiplier,
norm_kwargs=resolve_bn_args(kwargs),
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
**kwargs
)
model = _create_effnet(variant, pretrained, **model_kwargs)
@ -1290,13 +1460,35 @@ def efficientnet_b3_pruned(pretrained=False, **kwargs):
@register_model
def efficientnet_v2s(pretrained=False, **kwargs):
def efficientnetv2_rw_s(pretrained=False, **kwargs):
""" EfficientNet-V2 Small.
NOTE: This is my initial (pre official code release) w/ some differences.
See efficientnetv2_s and tf_efficientnetv2_s for versions that match the official w/ PyTorch vs TF padding
"""
model = _gen_efficientnetv2_s('efficientnetv2_rw_s', rw=True, pretrained=pretrained, **kwargs)
return model
@register_model
def efficientnetv2_s(pretrained=False, **kwargs):
""" EfficientNet-V2 Small. """
model = _gen_efficientnet_v2s(
'efficientnet_v2s', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
model = _gen_efficientnetv2_s('efficientnetv2_s', pretrained=pretrained, **kwargs)
return model
@register_model
def efficientnetv2_m(pretrained=False, **kwargs):
""" EfficientNet-V2 Medium. """
model = _gen_efficientnetv2_m('efficientnetv2_m', pretrained=pretrained, **kwargs)
return model
@register_model
def efficientnetv2_l(pretrained=False, **kwargs):
""" EfficientNet-V2 Large. """
model = _gen_efficientnetv2_l('efficientnetv2_l', pretrained=pretrained, **kwargs)
return model
@register_model
def tf_efficientnet_b0(pretrained=False, **kwargs):
@ -1708,6 +1900,133 @@ def tf_efficientnet_lite4(pretrained=False, **kwargs):
return model
@register_model
def tf_efficientnetv2_s(pretrained=False, **kwargs):
""" EfficientNet-V2 Small. Tensorflow compatible variant """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnetv2_s('tf_efficientnetv2_s', pretrained=pretrained, **kwargs)
return model
@register_model
def tf_efficientnetv2_m(pretrained=False, **kwargs):
""" EfficientNet-V2 Medium. Tensorflow compatible variant """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnetv2_m('tf_efficientnetv2_m', pretrained=pretrained, **kwargs)
return model
@register_model
def tf_efficientnetv2_l(pretrained=False, **kwargs):
""" EfficientNet-V2 Large. Tensorflow compatible variant """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnetv2_l('tf_efficientnetv2_l', pretrained=pretrained, **kwargs)
return model
@register_model
def tf_efficientnetv2_s_21ft1k(pretrained=False, **kwargs):
""" EfficientNet-V2 Small. Pretrained on ImageNet-21k, fine-tuned on 1k. Tensorflow compatible variant
"""
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnetv2_s('tf_efficientnetv2_s_21ft1k', pretrained=pretrained, **kwargs)
return model
@register_model
def tf_efficientnetv2_m_21ft1k(pretrained=False, **kwargs):
""" EfficientNet-V2 Medium. Pretrained on ImageNet-21k, fine-tuned on 1k. Tensorflow compatible variant
"""
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnetv2_m('tf_efficientnetv2_m_21ft1k', pretrained=pretrained, **kwargs)
return model
@register_model
def tf_efficientnetv2_l_21ft1k(pretrained=False, **kwargs):
""" EfficientNet-V2 Large. Pretrained on ImageNet-21k, fine-tuned on 1k. Tensorflow compatible variant
"""
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnetv2_l('tf_efficientnetv2_l_21ft1k', pretrained=pretrained, **kwargs)
return model
@register_model
def tf_efficientnetv2_s_21k(pretrained=False, **kwargs):
""" EfficientNet-V2 Small w/ ImageNet-21k pretrained weights. Tensorflow compatible variant
"""
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnetv2_s('tf_efficientnetv2_s_21k', pretrained=pretrained, **kwargs)
return model
@register_model
def tf_efficientnetv2_m_21k(pretrained=False, **kwargs):
""" EfficientNet-V2 Medium w/ ImageNet-21k pretrained weights. Tensorflow compatible variant
"""
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnetv2_m('tf_efficientnetv2_m_21k', pretrained=pretrained, **kwargs)
return model
@register_model
def tf_efficientnetv2_l_21k(pretrained=False, **kwargs):
""" EfficientNet-V2 Large w/ ImageNet-21k pretrained weights. Tensorflow compatible variant
"""
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnetv2_l('tf_efficientnetv2_l_21k', pretrained=pretrained, **kwargs)
return model
@register_model
def tf_efficientnetv2_b0(pretrained=False, **kwargs):
""" EfficientNet-V2-B0. Tensorflow compatible variant """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnetv2_base('tf_efficientnetv2_b0', pretrained=pretrained, **kwargs)
return model
@register_model
def tf_efficientnetv2_b1(pretrained=False, **kwargs):
""" EfficientNet-V2-B1. Tensorflow compatible variant """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnetv2_base(
'tf_efficientnetv2_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
return model
@register_model
def tf_efficientnetv2_b2(pretrained=False, **kwargs):
""" EfficientNet-V2-B2. Tensorflow compatible variant """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnetv2_base(
'tf_efficientnetv2_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
return model
@register_model
def tf_efficientnetv2_b3(pretrained=False, **kwargs):
""" EfficientNet-V2-B3. Tensorflow compatible variant """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnetv2_base(
'tf_efficientnetv2_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
return model
@register_model
def mixnet_s(pretrained=False, **kwargs):
"""Creates a MixNet Small model.

@ -7,106 +7,34 @@ import torch
import torch.nn as nn
from torch.nn import functional as F
from .layers import create_conv2d, drop_path, get_act_layer
from .layers import create_conv2d, drop_path, make_divisible
from .layers.activations import sigmoid
# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)
# NOTE: momentum varies btw .99 and .9997 depending on source
# .99 in official TF TPU impl
# .9997 (/w .999 in search space) for paper
BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
BN_EPS_TF_DEFAULT = 1e-3
_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)
def get_bn_args_tf():
return _BN_ARGS_TF.copy()
def resolve_bn_args(kwargs):
bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {}
bn_momentum = kwargs.pop('bn_momentum', None)
if bn_momentum is not None:
bn_args['momentum'] = bn_momentum
bn_eps = kwargs.pop('bn_eps', None)
if bn_eps is not None:
bn_args['eps'] = bn_eps
return bn_args
_SE_ARGS_DEFAULT = dict(
gate_fn=sigmoid,
act_layer=None,
reduce_mid=False,
divisor=1)
def resolve_se_args(kwargs, in_chs, act_layer=None):
se_kwargs = kwargs.copy() if kwargs is not None else {}
# fill in args that aren't specified with the defaults
for k, v in _SE_ARGS_DEFAULT.items():
se_kwargs.setdefault(k, v)
# some models, like MobilNetV3, calculate SE reduction chs from the containing block's mid_ch instead of in_ch
if not se_kwargs.pop('reduce_mid'):
se_kwargs['reduced_base_chs'] = in_chs
# act_layer override, if it remains None, the containing block's act_layer will be used
if se_kwargs['act_layer'] is None:
assert act_layer is not None
se_kwargs['act_layer'] = act_layer
return se_kwargs
def resolve_act_layer(kwargs, default='relu'):
act_layer = kwargs.pop('act_layer', default)
if isinstance(act_layer, str):
act_layer = get_act_layer(act_layer)
return act_layer
def make_divisible(v, divisor=8, min_value=None):
min_value = min_value or divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
"""Round number of filters based on depth multiplier."""
if not multiplier:
return channels
channels *= multiplier
return make_divisible(channels, divisor, channel_min)
class ChannelShuffle(nn.Module):
# FIXME haven't used yet
def __init__(self, groups):
super(ChannelShuffle, self).__init__()
self.groups = groups
def forward(self, x):
"""Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]"""
N, C, H, W = x.size()
g = self.groups
assert C % g == 0, "Incompatible group size {} for input channel {}".format(
g, C
)
return (
x.view(N, g, int(C / g), H, W)
.permute(0, 2, 1, 3, 4)
.contiguous()
.view(N, C, H, W)
)
__all__ = [
'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual']
class SqueezeExcite(nn.Module):
def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None,
act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1, **_):
""" Squeeze-and-Excitation w/ specific features for EfficientNet/MobileNet family
Args:
in_chs (int): input channels to layer
se_ratio (float): ratio of squeeze reduction
act_layer (nn.Module): activation layer of containing block
gate_fn (Callable): attention gate function
block_in_chs (int): input channels of containing block (for calculating reduction from)
reduce_from_block (bool): calculate reduction from block input channels if True
force_act_layer (nn.Module): override block's activation fn if this is set/bound
divisor (int): make reduction channels divisible by this
"""
def __init__(
self, in_chs, se_ratio=0.25, act_layer=nn.ReLU, gate_fn=sigmoid,
block_in_chs=None, reduce_from_block=True, force_act_layer=None, divisor=1):
super(SqueezeExcite, self).__init__()
reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor)
reduced_chs = (block_in_chs or in_chs) if reduce_from_block else in_chs
reduced_chs = make_divisible(reduced_chs * se_ratio, divisor)
act_layer = force_act_layer or act_layer
self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
self.act1 = act_layer(inplace=True)
self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
@ -121,13 +49,16 @@ class SqueezeExcite(nn.Module):
class ConvBnAct(nn.Module):
def __init__(self, in_chs, out_chs, kernel_size,
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d, norm_kwargs=None):
""" Conv + Norm Layer + Activation w/ optional skip connection
"""
def __init__(
self, in_chs, out_chs, kernel_size, stride=1, dilation=1, pad_type='',
skip=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_path_rate=0.):
super(ConvBnAct, self).__init__()
norm_kwargs = norm_kwargs or {}
self.has_residual = skip and stride == 1 and in_chs == out_chs
self.drop_path_rate = drop_path_rate
self.conv = create_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type)
self.bn1 = norm_layer(out_chs, **norm_kwargs)
self.bn1 = norm_layer(out_chs)
self.act1 = act_layer(inplace=True)
def feature_info(self, location):
@ -138,9 +69,14 @@ class ConvBnAct(nn.Module):
return info
def forward(self, x):
shortcut = x
x = self.conv(x)
x = self.bn1(x)
x = self.act1(x)
if self.has_residual:
if self.drop_path_rate > 0.:
x = drop_path(x, self.drop_path_rate, self.training)
x += shortcut
return x
@ -149,31 +85,26 @@ class DepthwiseSeparableConv(nn.Module):
Used for DS convs in MobileNet-V1 and in the place of IR blocks that have no expansion
(factor of 1.0). This is an alternative to having a IR with an optional first pw conv.
"""
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
pw_kernel_size=1, pw_act=False, se_ratio=0., se_kwargs=None,
norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0.):
def __init__(
self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
noskip=False, pw_kernel_size=1, pw_act=False, se_ratio=0.,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.):
super(DepthwiseSeparableConv, self).__init__()
norm_kwargs = norm_kwargs or {}
has_se = se_ratio is not None and se_ratio > 0.
has_se = se_layer is not None and se_ratio > 0.
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
self.has_pw_act = pw_act # activation after point-wise conv
self.drop_path_rate = drop_path_rate
self.conv_dw = create_conv2d(
in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, depthwise=True)
self.bn1 = norm_layer(in_chs, **norm_kwargs)
self.bn1 = norm_layer(in_chs)
self.act1 = act_layer(inplace=True)
# Squeeze-and-excitation
if has_se:
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs)
else:
self.se = None
self.se = se_layer(in_chs, se_ratio=se_ratio, act_layer=act_layer) if has_se else nn.Identity()
self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
self.bn2 = norm_layer(out_chs, **norm_kwargs)
self.bn2 = norm_layer(out_chs)
self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity()
def feature_info(self, location):
@ -190,8 +121,7 @@ class DepthwiseSeparableConv(nn.Module):
x = self.bn1(x)
x = self.act1(x)
if self.se is not None:
x = self.se(x)
x = self.se(x)
x = self.conv_pw(x)
x = self.bn2(x)
@ -214,41 +144,36 @@ class InvertedResidual(nn.Module):
* MobileNet-V3 - https://arxiv.org/abs/1905.02244
"""
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
conv_kwargs=None, drop_path_rate=0.):
def __init__(
self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, se_ratio=0.,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, conv_kwargs=None, drop_path_rate=0.):
super(InvertedResidual, self).__init__()
norm_kwargs = norm_kwargs or {}
conv_kwargs = conv_kwargs or {}
mid_chs = make_divisible(in_chs * exp_ratio)
has_se = se_ratio is not None and se_ratio > 0.
has_se = se_layer is not None and se_ratio > 0.
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
self.drop_path_rate = drop_path_rate
# Point-wise expansion
self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
self.bn1 = norm_layer(mid_chs, **norm_kwargs)
self.bn1 = norm_layer(mid_chs)
self.act1 = act_layer(inplace=True)
# Depth-wise convolution
self.conv_dw = create_conv2d(
mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation,
padding=pad_type, depthwise=True, **conv_kwargs)
self.bn2 = norm_layer(mid_chs, **norm_kwargs)
self.bn2 = norm_layer(mid_chs)
self.act2 = act_layer(inplace=True)
# Squeeze-and-excitation
if has_se:
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
else:
self.se = None
self.se = se_layer(
mid_chs, se_ratio=se_ratio, act_layer=act_layer, block_in_chs=in_chs) if has_se else nn.Identity()
# Point-wise linear projection
self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
self.bn3 = norm_layer(out_chs, **norm_kwargs)
self.bn3 = norm_layer(out_chs)
def feature_info(self, location):
if location == 'expansion': # after SE, input to PWL
@ -271,8 +196,7 @@ class InvertedResidual(nn.Module):
x = self.act2(x)
# Squeeze-and-excitation
if self.se is not None:
x = self.se(x)
x = self.se(x)
# Point-wise linear projection
x = self.conv_pwl(x)
@ -289,11 +213,10 @@ class InvertedResidual(nn.Module):
class CondConvResidual(InvertedResidual):
""" Inverted residual block w/ CondConv routing"""
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
num_experts=0, drop_path_rate=0.):
def __init__(
self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, se_ratio=0.,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, num_experts=0, drop_path_rate=0.):
self.num_experts = num_experts
conv_kwargs = dict(num_experts=self.num_experts)
@ -301,9 +224,8 @@ class CondConvResidual(InvertedResidual):
super(CondConvResidual, self).__init__(
in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, pad_type=pad_type,
act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_kwargs=se_kwargs,
norm_layer=norm_layer, norm_kwargs=norm_kwargs, conv_kwargs=conv_kwargs,
drop_path_rate=drop_path_rate)
pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_layer=se_layer,
norm_layer=norm_layer, conv_kwargs=conv_kwargs, drop_path_rate=drop_path_rate)
self.routing_fn = nn.Linear(in_chs, self.num_experts)
@ -325,8 +247,7 @@ class CondConvResidual(InvertedResidual):
x = self.act2(x)
# Squeeze-and-excitation
if self.se is not None:
x = self.se(x)
x = self.se(x)
# Point-wise linear projection
x = self.conv_pwl(x, routing_weights)
@ -351,36 +272,32 @@ class EdgeResidual(nn.Module):
* EfficientNet-V2 - https://arxiv.org/abs/2104.00298
"""
def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0,
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1,
se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
drop_path_rate=0.):
def __init__(
self, in_chs, out_chs, exp_kernel_size=3, stride=1, dilation=1, pad_type='',
force_in_chs=0, noskip=False, exp_ratio=1.0, pw_kernel_size=1, se_ratio=0.,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.):
super(EdgeResidual, self).__init__()
norm_kwargs = norm_kwargs or {}
if fake_in_chs > 0:
mid_chs = make_divisible(fake_in_chs * exp_ratio)
if force_in_chs > 0:
mid_chs = make_divisible(force_in_chs * exp_ratio)
else:
mid_chs = make_divisible(in_chs * exp_ratio)
has_se = se_ratio is not None and se_ratio > 0.
has_se = se_layer is not None and se_ratio > 0.
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
self.drop_path_rate = drop_path_rate
# Expansion convolution
self.conv_exp = create_conv2d(
in_chs, mid_chs, exp_kernel_size, stride=stride, dilation=dilation, padding=pad_type)
self.bn1 = norm_layer(mid_chs, **norm_kwargs)
self.bn1 = norm_layer(mid_chs)
self.act1 = act_layer(inplace=True)
# Squeeze-and-excitation
if has_se:
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
else:
self.se = None
self.se = SqueezeExcite(
mid_chs, se_ratio=se_ratio, act_layer=act_layer, block_in_chs=in_chs) if has_se else nn.Identity()
# Point-wise linear projection
self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type)
self.bn2 = norm_layer(out_chs, **norm_kwargs)
self.bn2 = norm_layer(out_chs)
def feature_info(self, location):
if location == 'expansion': # after SE, before PWL
@ -398,8 +315,7 @@ class EdgeResidual(nn.Module):
x = self.act1(x)
# Squeeze-and-excitation
if self.se is not None:
x = self.se(x)
x = self.se(x)
# Point-wise linear projection
x = self.conv_pwl(x)

@ -14,13 +14,55 @@ from copy import deepcopy
import torch.nn as nn
from .efficientnet_blocks import *
from .layers import CondConv2d, get_condconv_initializer
from .layers import CondConv2d, get_condconv_initializer, get_act_layer, make_divisible
__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights"]
__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights",
'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT']
_logger = logging.getLogger(__name__)
_DEBUG_BUILDER = False
# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)
# NOTE: momentum varies btw .99 and .9997 depending on source
# .99 in official TF TPU impl
# .9997 (/w .999 in search space) for paper
BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
BN_EPS_TF_DEFAULT = 1e-3
_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)
def get_bn_args_tf():
return _BN_ARGS_TF.copy()
def resolve_bn_args(kwargs):
bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {}
bn_momentum = kwargs.pop('bn_momentum', None)
if bn_momentum is not None:
bn_args['momentum'] = bn_momentum
bn_eps = kwargs.pop('bn_eps', None)
if bn_eps is not None:
bn_args['eps'] = bn_eps
return bn_args
def resolve_act_layer(kwargs, default='relu'):
act_layer = kwargs.pop('act_layer', default)
if isinstance(act_layer, str):
act_layer = get_act_layer(act_layer)
return act_layer
def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None, round_limit=0.9):
"""Round number of filters based on depth multiplier."""
if not multiplier:
return channels
return make_divisible(channels * multiplier, divisor, channel_min, round_limit=round_limit)
def _log_info_if(msg, condition):
if condition:
_logger.info(msg)
@ -63,11 +105,13 @@ def _decode_block_str(block_str):
block_type = ops[0] # take the block type off the front
ops = ops[1:]
options = {}
noskip = False
skip = None
for op in ops:
# string options being checked on individual basis, combine if they grow
if op == 'noskip':
noskip = True
skip = False # force no skip connection
elif op == 'skip':
skip = True # force a skip connection
elif op.startswith('n'):
# activation fn
key = op[0]
@ -94,7 +138,7 @@ def _decode_block_str(block_str):
act_layer = options['n'] if 'n' in options else None
exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
fake_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def
force_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def
num_repeat = int(options['r'])
# each type of block has different valid arguments, fill accordingly
@ -106,10 +150,10 @@ def _decode_block_str(block_str):
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
exp_ratio=float(options['e']),
se_ratio=float(options['se']) if 'se' in options else None,
se_ratio=float(options['se']) if 'se' in options else 0.,
stride=int(options['s']),
act_layer=act_layer,
noskip=noskip,
noskip=skip is False,
)
if 'cc' in options:
block_args['num_experts'] = int(options['cc'])
@ -119,11 +163,11 @@ def _decode_block_str(block_str):
dw_kernel_size=_parse_ksize(options['k']),
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
se_ratio=float(options['se']) if 'se' in options else None,
se_ratio=float(options['se']) if 'se' in options else 0.,
stride=int(options['s']),
act_layer=act_layer,
pw_act=block_type == 'dsa',
noskip=block_type == 'dsa' or noskip,
noskip=block_type == 'dsa' or skip is False,
)
elif block_type == 'er':
block_args = dict(
@ -132,11 +176,11 @@ def _decode_block_str(block_str):
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
exp_ratio=float(options['e']),
fake_in_chs=fake_in_chs,
se_ratio=float(options['se']) if 'se' in options else None,
force_in_chs=force_in_chs,
se_ratio=float(options['se']) if 'se' in options else 0.,
stride=int(options['s']),
act_layer=act_layer,
noskip=noskip,
noskip=skip is False,
)
elif block_type == 'cn':
block_args = dict(
@ -145,6 +189,7 @@ def _decode_block_str(block_str):
out_chs=int(options['c']),
stride=int(options['s']),
act_layer=act_layer,
skip=skip is True,
)
else:
assert False, 'Unknown block type (%s)' % block_type
@ -219,19 +264,14 @@ class EfficientNetBuilder:
https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py
"""
def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
output_stride=32, pad_type='', act_layer=None, se_kwargs=None,
norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0., feature_location='',
verbose=False):
self.channel_multiplier = channel_multiplier
self.channel_divisor = channel_divisor
self.channel_min = channel_min
def __init__(self, output_stride=32, pad_type='', round_chs_fn=round_channels,
act_layer=None, norm_layer=None, se_layer=None, drop_path_rate=0., feature_location=''):
self.output_stride = output_stride
self.pad_type = pad_type
self.round_chs_fn = round_chs_fn
self.act_layer = act_layer
self.se_kwargs = se_kwargs
self.norm_layer = norm_layer
self.norm_kwargs = norm_kwargs
self.se_layer = se_layer
self.drop_path_rate = drop_path_rate
if feature_location == 'depthwise':
# old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense
@ -239,45 +279,39 @@ class EfficientNetBuilder:
feature_location = 'expansion'
self.feature_location = feature_location
assert feature_location in ('bottleneck', 'expansion', '')
self.verbose = verbose
self.verbose = _DEBUG_BUILDER
# state updated during build, consumed by model
self.in_chs = None
self.features = []
def _round_channels(self, chs):
return round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)
def _make_block(self, ba, block_idx, block_count):
drop_path_rate = self.drop_path_rate * block_idx / block_count
bt = ba.pop('block_type')
ba['in_chs'] = self.in_chs
ba['out_chs'] = self._round_channels(ba['out_chs'])
if 'fake_in_chs' in ba and ba['fake_in_chs']:
# FIXME this is a hack to work around mismatch in origin impl input filters
ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
ba['norm_layer'] = self.norm_layer
ba['norm_kwargs'] = self.norm_kwargs
ba['out_chs'] = self.round_chs_fn(ba['out_chs'])
if 'force_in_chs' in ba and ba['force_in_chs']:
# NOTE this is a hack to work around mismatch in TF EdgeEffNet impl
ba['force_in_chs'] = self.round_chs_fn(ba['force_in_chs'])
ba['pad_type'] = self.pad_type
# block act fn overrides the model default
ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
assert ba['act_layer'] is not None
if bt == 'ir':
ba['norm_layer'] = self.norm_layer
if bt != 'cn':
ba['se_layer'] = self.se_layer
ba['drop_path_rate'] = drop_path_rate
ba['se_kwargs'] = self.se_kwargs
if bt == 'ir':
_log_info_if(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
if ba.get('num_experts', 0) > 0:
block = CondConvResidual(**ba)
else:
block = InvertedResidual(**ba)
elif bt == 'ds' or bt == 'dsa':
ba['drop_path_rate'] = drop_path_rate
ba['se_kwargs'] = self.se_kwargs
_log_info_if(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
block = DepthwiseSeparableConv(**ba)
elif bt == 'er':
ba['drop_path_rate'] = drop_path_rate
ba['se_kwargs'] = self.se_kwargs
_log_info_if(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
block = EdgeResidual(**ba)
elif bt == 'cn':
@ -285,8 +319,8 @@ class EfficientNetBuilder:
block = ConvBnAct(**ba)
else:
assert False, 'Uknkown block type (%s) while building model.' % bt
self.in_chs = ba['out_chs'] # update in_chs for arg of next block
self.in_chs = ba['out_chs'] # update in_chs for arg of next block
return block
def __call__(self, in_chs, model_block_args):

@ -13,8 +13,8 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .layers import SelectAdaptivePool2d, Linear, hard_sigmoid
from .efficientnet_blocks import SqueezeExcite, ConvBnAct, make_divisible
from .layers import SelectAdaptivePool2d, Linear, hard_sigmoid, make_divisible
from .efficientnet_blocks import SqueezeExcite, ConvBnAct
from .helpers import build_model_with_cfg
from .registry import register_model
@ -110,7 +110,6 @@ class GhostBottleneck(nn.Module):
nn.BatchNorm2d(out_chs),
)
def forward(self, x):
shortcut = x

@ -1,10 +1,14 @@
from functools import partial
import torch.nn as nn
from .efficientnet_builder import decode_arch_def, resolve_bn_args
from .mobilenetv3 import MobileNetV3, MobileNetV3Features, build_model_with_cfg, default_cfg_for_features
from .layers import hard_sigmoid
from .efficientnet_blocks import resolve_act_layer
from .registry import register_model
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .efficientnet_blocks import SqueezeExcite
from .efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args
from .helpers import build_model_with_cfg, default_cfg_for_features
from .layers import get_act_fn
from .mobilenetv3 import MobileNetV3, MobileNetV3Features
from .registry import register_model
def _cfg(url='', **kwargs):
@ -35,15 +39,15 @@ def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs):
"""
num_features = 1280
se_layer = partial(
SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), force_act_layer=nn.ReLU, reduce_from_block=False, divisor=8)
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
num_features=num_features,
stem_size=32,
channel_multiplier=1,
norm_kwargs=resolve_bn_args(kwargs),
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
act_layer=resolve_act_layer(kwargs, 'hard_swish'),
se_kwargs=dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisor=8),
se_layer=se_layer,
**kwargs,
)

@ -22,10 +22,10 @@ to_4tuple = _ntuple(4)
to_ntuple = _ntuple
def make_divisible(v, divisor=8, min_value=None):
def make_divisible(v, divisor=8, min_value=None, round_limit=.9):
min_value = min_value or divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
if new_v < round_limit * v:
new_v += divisor
return new_v
return new_v

@ -5,23 +5,25 @@ A PyTorch impl of MobileNet-V3, compatible with TF weights from official impl.
Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244
Hacked together by / Copyright 2020 Ross Wightman
Hacked together by / Copyright 2021 Ross Wightman
"""
from functools import partial
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
from .efficientnet_blocks import SqueezeExcite
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights,\
round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from .features import FeatureInfo, FeatureHooks
from .helpers import build_model_with_cfg, default_cfg_for_features
from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, hard_sigmoid
from .registry import register_model
__all__ = ['MobileNetV3']
__all__ = ['MobileNetV3', 'MobileNetV3Features']
def _cfg(url='', **kwargs):
@ -47,9 +49,11 @@ default_cfgs = {
url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mobilenetv3_large_100_in21k_miil.pth', num_classes=11221),
'mobilenetv3_small_075': _cfg(url=''),
'mobilenetv3_small_100': _cfg(url=''),
'mobilenetv3_rw': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth',
interpolation='bicubic'),
'tf_mobilenetv3_large_075': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
@ -70,8 +74,6 @@ default_cfgs = {
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
}
_DEBUG = False
class MobileNetV3(nn.Module):
""" MobiletNet-V3
@ -84,24 +86,26 @@ class MobileNetV3(nn.Module):
"""
def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True,
channel_multiplier=1.0, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'):
pad_type='', act_layer=None, norm_layer=None, se_layer=None,
round_chs_fn=round_channels, drop_rate=0., drop_path_rate=0., global_pool='avg'):
super(MobileNetV3, self).__init__()
act_layer = act_layer or nn.ReLU
norm_layer = norm_layer or nn.BatchNorm2d
se_layer = se_layer or SqueezeExcite
self.num_classes = num_classes
self.num_features = num_features
self.drop_rate = drop_rate
# Stem
stem_size = round_channels(stem_size, channel_multiplier)
stem_size = round_chs_fn(stem_size)
self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = norm_layer(stem_size, **norm_kwargs)
self.bn1 = norm_layer(stem_size)
self.act1 = act_layer(inplace=True)
# Middle stages (IR/ER/DS Blocks)
builder = EfficientNetBuilder(
channel_multiplier, 8, None, 32, pad_type, act_layer, se_kwargs,
norm_layer, norm_kwargs, drop_path_rate, verbose=_DEBUG)
output_stride=32, pad_type=pad_type, round_chs_fn=round_chs_fn,
act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate)
self.blocks = nn.Sequential(*builder(stem_size, block_args))
self.feature_info = builder.features
head_chs = builder.in_chs
@ -158,23 +162,25 @@ class MobileNetV3Features(nn.Module):
"""
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck',
in_chans=3, stem_size=16, channel_multiplier=1.0, output_stride=32, pad_type='',
act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., se_kwargs=None,
norm_layer=nn.BatchNorm2d, norm_kwargs=None):
in_chans=3, stem_size=16, output_stride=32, pad_type='', round_chs_fn=round_channels,
act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.):
super(MobileNetV3Features, self).__init__()
norm_kwargs = norm_kwargs or {}
act_layer = act_layer or nn.ReLU
norm_layer = norm_layer or nn.BatchNorm2d
se_layer = se_layer or SqueezeExcite
self.drop_rate = drop_rate
# Stem
stem_size = round_channels(stem_size, channel_multiplier)
stem_size = round_chs_fn(stem_size)
self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = norm_layer(stem_size, **norm_kwargs)
self.bn1 = norm_layer(stem_size)
self.act1 = act_layer(inplace=True)
# Middle stages (IR/ER/DS Blocks)
builder = EfficientNetBuilder(
channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs,
norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn,
act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer,
drop_path_rate=drop_path_rate, feature_location=feature_location)
self.blocks = nn.Sequential(*builder(stem_size, block_args))
self.feature_info = FeatureInfo(builder.features, out_indices)
self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices}
@ -253,10 +259,10 @@ def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kw
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
head_bias=False,
channel_multiplier=channel_multiplier,
norm_kwargs=resolve_bn_args(kwargs),
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
act_layer=resolve_act_layer(kwargs, 'hard_swish'),
se_kwargs=dict(gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True, divisor=1),
se_layer=partial(SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), reduce_from_block=False),
**kwargs,
)
model = _create_mnv3(variant, pretrained, **model_kwargs)
@ -344,15 +350,16 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
# stage 6, 7x7 in
['cn_r1_k1_s1_c960'], # hard-swish
]
se_layer = partial(
SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), force_act_layer=nn.ReLU, reduce_from_block=False, divisor=8)
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
num_features=num_features,
stem_size=16,
channel_multiplier=channel_multiplier,
norm_kwargs=resolve_bn_args(kwargs),
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
act_layer=act_layer,
se_kwargs=dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisor=8),
se_layer=se_layer,
**kwargs,
)
model = _create_mnv3(variant, pretrained, **model_kwargs)

@ -1 +1 @@
__version__ = '0.4.8'
__version__ = '0.4.9'

Loading…
Cancel
Save