|
|
|
@@ -2,6 +2,9 @@
|
|
|
|
|
|
|
|
|
|
An implementation of EfficientNet that covers a variety of related models with efficient architectures:
|
|
|
|
|
|
|
|
|
|
* EfficientNet-V2
|
|
|
|
|
- `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298
|
|
|
|
|
|
|
|
|
|
* EfficientNet (B0-B8, L2 + Tensorflow pretrained AutoAug/RandAug/AdvProp/NoisyStudent weight ports)
|
|
|
|
|
- EfficientNet: Rethinking Model Scaling for CNNs - https://arxiv.org/abs/1905.11946
|
|
|
|
|
- CondConv: Conditionally Parameterized Convolutions for Efficient Inference - https://arxiv.org/abs/1904.04971
|
|
|
|
@@ -22,23 +25,26 @@ An implementation of EfficienNet that covers variety of related models with effi
|
|
|
|
|
|
|
|
|
|
* And likely more...
|
|
|
|
|
|
|
|
|
|
Hacked together by / Copyright 2020 Ross Wightman
|
|
|
|
|
Hacked together by / Copyright 2021 Ross Wightman
|
|
|
|
|
"""
|
|
|
|
|
from functools import partial
|
|
|
|
|
from typing import List
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
|
from typing import List
|
|
|
|
|
|
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
|
|
|
|
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
|
|
|
|
|
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
|
|
|
|
|
from .efficientnet_blocks import SqueezeExcite
|
|
|
|
|
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights,\
|
|
|
|
|
round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
|
|
|
|
|
from .features import FeatureInfo, FeatureHooks
|
|
|
|
|
from .helpers import build_model_with_cfg, default_cfg_for_features
|
|
|
|
|
from .layers import create_conv2d, create_classifier
|
|
|
|
|
from .registry import register_model
|
|
|
|
|
|
|
|
|
|
__all__ = ['EfficientNet']
|
|
|
|
|
__all__ = ['EfficientNet', 'EfficientNetFeatures']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _cfg(url='', **kwargs):
|
|
|
|
@@ -149,9 +155,20 @@ default_cfgs = {
|
|
|
|
|
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb3_pruned_5abcc29f.pth',
|
|
|
|
|
input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
|
|
|
|
|
|
|
|
|
|
'efficientnet_v2s': _cfg(
|
|
|
|
|
'efficientnetv2_rw_s': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_v2s_ra2_288-a6477665.pth',
|
|
|
|
|
input_size=(3, 288, 288), test_input_size=(3, 384, 384), pool_size=(9, 9), crop_pct=1.0), # FIXME WIP
|
|
|
|
|
input_size=(3, 288, 288), test_input_size=(3, 384, 384), pool_size=(9, 9), crop_pct=1.0),
|
|
|
|
|
|
|
|
|
|
'efficientnetv2_s': _cfg(
|
|
|
|
|
url='',
|
|
|
|
|
input_size=(3, 288, 288), test_input_size=(3, 384, 384), pool_size=(9, 9), crop_pct=1.0),
|
|
|
|
|
'efficientnetv2_m': _cfg(
|
|
|
|
|
url='',
|
|
|
|
|
input_size=(3, 320, 320), test_input_size=(3, 416, 416), pool_size=(10, 10), crop_pct=1.0),
|
|
|
|
|
'efficientnetv2_l': _cfg(
|
|
|
|
|
url='',
|
|
|
|
|
input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0),
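# (assumed reading) for these V2 cfgs, input_size is the default train/eval resolution and
# test_input_size the larger resolution the weights are meant to be evaluated at, reflecting
# the paper's progressive-resize training.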
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'tf_efficientnet_b0': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
|
|
|
|
@@ -298,6 +315,58 @@ default_cfgs = {
|
|
|
|
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
|
|
|
|
input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.920, interpolation='bilinear'),
|
|
|
|
|
|
|
|
|
|
'tf_efficientnetv2_s': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s-eb54923e.pth',
|
|
|
|
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
|
|
|
|
input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0),
|
|
|
|
|
'tf_efficientnetv2_m': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m-cc09e0cd.pth',
|
|
|
|
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
|
|
|
|
input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0),
|
|
|
|
|
'tf_efficientnetv2_l': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l-d664b728.pth',
|
|
|
|
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
|
|
|
|
input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0),
|
|
|
|
|
|
|
|
|
|
'tf_efficientnetv2_s_21kft1k': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21kft1k-d7dafa41.pth',
|
|
|
|
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
|
|
|
|
input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0),
|
|
|
|
|
'tf_efficientnetv2_m_21kft1k': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m_21kft1k-bf41664a.pth',
|
|
|
|
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
|
|
|
|
input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0),
|
|
|
|
|
'tf_efficientnetv2_l_21kft1k': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l_21kft1k-60127a9d.pth',
|
|
|
|
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
|
|
|
|
input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0),
|
|
|
|
|
|
|
|
|
|
'tf_efficientnetv2_s_21k': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21k-6337ad01.pth',
|
|
|
|
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843,
|
|
|
|
|
input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0),
|
|
|
|
|
'tf_efficientnetv2_m_21k': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m_21k-361418a2.pth',
|
|
|
|
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843,
|
|
|
|
|
input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0),
|
|
|
|
|
'tf_efficientnetv2_l_21k': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l_21k-91a19ec9.pth',
|
|
|
|
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843,
|
|
|
|
|
input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0),
|
|
|
|
|
|
|
|
|
|
'tf_efficientnetv2_b0': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b0-c7cc451f.pth',
|
|
|
|
|
input_size=(3, 192, 192), test_input_size=(3, 224, 224), pool_size=(6, 6)),
|
|
|
|
|
'tf_efficientnetv2_b1': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b1-be6e41b0.pth',
|
|
|
|
|
input_size=(3, 192, 192), test_input_size=(3, 240, 240), pool_size=(6, 6), crop_pct=0.882),
|
|
|
|
|
'tf_efficientnetv2_b2': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b2-847de54e.pth',
|
|
|
|
|
input_size=(3, 208, 208), test_input_size=(3, 260, 260), pool_size=(7, 7), crop_pct=0.890),
|
|
|
|
|
'tf_efficientnetv2_b3': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b3-57773f13.pth',
|
|
|
|
|
input_size=(3, 240, 240), test_input_size=(3, 300, 300), pool_size=(8, 8), crop_pct=0.904),
|
|
|
|
|
|
|
|
|
|
'mixnet_s': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth'),
|
|
|
|
|
'mixnet_m': _cfg(
|
|
|
|
@@ -316,13 +385,12 @@ default_cfgs = {
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_l-6c92e0c8.pth'),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
_DEBUG = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EfficientNet(nn.Module):
|
|
|
|
|
""" (Generic) EfficientNet
|
|
|
|
|
|
|
|
|
|
A flexible and performant PyTorch implementation of efficient network architectures, including:
|
|
|
|
|
* EfficientNet-V2 Small, Medium, Large & B0-B3
|
|
|
|
|
* EfficientNet B0-B8, L2
|
|
|
|
|
* EfficientNet-EdgeTPU
|
|
|
|
|
* EfficientNet-CondConv
|
|
|
|
@@ -333,35 +401,35 @@ class EfficientNet(nn.Module):
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32,
|
|
|
|
|
channel_multiplier=1.0, channel_divisor=8, channel_min=None,
|
|
|
|
|
output_stride=32, pad_type='', fix_stem=False, act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
|
|
|
|
|
se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'):
|
|
|
|
|
def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32, fix_stem=False,
|
|
|
|
|
output_stride=32, pad_type='', round_chs_fn=round_channels, act_layer=None, norm_layer=None,
|
|
|
|
|
se_layer=None, drop_rate=0., drop_path_rate=0., global_pool='avg'):
|
|
|
|
|
super(EfficientNet, self).__init__()
|
|
|
|
|
norm_kwargs = norm_kwargs or {}
|
|
|
|
|
|
|
|
|
|
act_layer = act_layer or nn.ReLU
|
|
|
|
|
norm_layer = norm_layer or nn.BatchNorm2d
|
|
|
|
|
se_layer = se_layer or SqueezeExcite
|
|
|
|
|
self.num_classes = num_classes
|
|
|
|
|
self.num_features = num_features
|
|
|
|
|
self.drop_rate = drop_rate
|
|
|
|
|
|
|
|
|
|
# Stem
|
|
|
|
|
if not fix_stem:
|
|
|
|
|
stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
|
|
|
|
|
stem_size = round_chs_fn(stem_size)
|
|
|
|
|
self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
|
|
|
|
|
self.bn1 = norm_layer(stem_size, **norm_kwargs)
|
|
|
|
|
self.bn1 = norm_layer(stem_size)
|
|
|
|
|
self.act1 = act_layer(inplace=True)
|
|
|
|
|
|
|
|
|
|
# Middle stages (IR/ER/DS Blocks)
|
|
|
|
|
builder = EfficientNetBuilder(
|
|
|
|
|
channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs,
|
|
|
|
|
norm_layer, norm_kwargs, drop_path_rate, verbose=_DEBUG)
|
|
|
|
|
output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn,
|
|
|
|
|
act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate)
|
|
|
|
|
self.blocks = nn.Sequential(*builder(stem_size, block_args))
|
|
|
|
|
self.feature_info = builder.features
|
|
|
|
|
head_chs = builder.in_chs
|
|
|
|
|
|
|
|
|
|
# Head + Pooling
|
|
|
|
|
self.conv_head = create_conv2d(head_chs, self.num_features, 1, padding=pad_type)
|
|
|
|
|
self.bn2 = norm_layer(self.num_features, **norm_kwargs)
|
|
|
|
|
self.bn2 = norm_layer(self.num_features)
|
|
|
|
|
self.act2 = act_layer(inplace=True)
|
|
|
|
|
self.global_pool, self.classifier = create_classifier(
|
|
|
|
|
self.num_features, self.num_classes, pool_type=global_pool)
|
|
|
|
@@ -408,25 +476,27 @@ class EfficientNetFeatures(nn.Module):
|
|
|
|
|
and object detection models.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck',
|
|
|
|
|
in_chans=3, stem_size=32, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
|
|
|
|
|
output_stride=32, pad_type='', fix_stem=False, act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
|
|
|
|
|
se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
|
|
|
|
|
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3,
|
|
|
|
|
stem_size=32, fix_stem=False, output_stride=32, pad_type='', round_chs_fn=round_channels,
|
|
|
|
|
act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.):
|
|
|
|
|
super(EfficientNetFeatures, self).__init__()
|
|
|
|
|
norm_kwargs = norm_kwargs or {}
|
|
|
|
|
act_layer = act_layer or nn.ReLU
|
|
|
|
|
norm_layer = norm_layer or nn.BatchNorm2d
|
|
|
|
|
se_layer = se_layer or SqueezeExcite
|
|
|
|
|
self.drop_rate = drop_rate
|
|
|
|
|
|
|
|
|
|
# Stem
|
|
|
|
|
if not fix_stem:
|
|
|
|
|
stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
|
|
|
|
|
stem_size = round_chs_fn(stem_size)
|
|
|
|
|
self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
|
|
|
|
|
self.bn1 = norm_layer(stem_size, **norm_kwargs)
|
|
|
|
|
self.bn1 = norm_layer(stem_size)
|
|
|
|
|
self.act1 = act_layer(inplace=True)
|
|
|
|
|
|
|
|
|
|
# Middle stages (IR/ER/DS Blocks)
|
|
|
|
|
builder = EfficientNetBuilder(
|
|
|
|
|
channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs,
|
|
|
|
|
norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
|
|
|
|
|
output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn,
|
|
|
|
|
act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate,
|
|
|
|
|
feature_location=feature_location)
|
|
|
|
|
self.blocks = nn.Sequential(*builder(stem_size, block_args))
|
|
|
|
|
self.feature_info = FeatureInfo(builder.features, out_indices)
|
|
|
|
|
self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices}
|
|
|
|
@@ -505,8 +575,8 @@ def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs)
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
block_args=decode_arch_def(arch_def),
|
|
|
|
|
stem_size=32,
|
|
|
|
|
channel_multiplier=channel_multiplier,
|
|
|
|
|
norm_kwargs=resolve_bn_args(kwargs),
|
|
|
|
|
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
|
|
|
|
|
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
|
|
|
|
|
**kwargs
|
|
|
|
|
)
|
|
|
|
|
model = _create_effnet(variant, pretrained, **model_kwargs)
|
|
|
|
@@ -541,8 +611,8 @@ def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs)
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
block_args=decode_arch_def(arch_def),
|
|
|
|
|
stem_size=32,
|
|
|
|
|
channel_multiplier=channel_multiplier,
|
|
|
|
|
norm_kwargs=resolve_bn_args(kwargs),
|
|
|
|
|
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
|
|
|
|
|
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
|
|
|
|
|
**kwargs
|
|
|
|
|
)
|
|
|
|
|
model = _create_effnet(variant, pretrained, **model_kwargs)
|
|
|
|
@@ -570,8 +640,8 @@ def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwar
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
block_args=decode_arch_def(arch_def),
|
|
|
|
|
stem_size=8,
|
|
|
|
|
channel_multiplier=channel_multiplier,
|
|
|
|
|
norm_kwargs=resolve_bn_args(kwargs),
|
|
|
|
|
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
|
|
|
|
|
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
|
|
|
|
|
**kwargs
|
|
|
|
|
)
|
|
|
|
|
model = _create_effnet(variant, pretrained, **model_kwargs)
|
|
|
|
@@ -593,13 +663,14 @@ def _gen_mobilenet_v2(
|
|
|
|
|
['ir_r3_k3_s2_e6_c160'],
|
|
|
|
|
['ir_r1_k3_s1_e6_c320'],
|
|
|
|
|
]
|
|
|
|
|
round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head),
|
|
|
|
|
num_features=1280 if fix_stem_head else round_channels(1280, channel_multiplier, 8, None),
|
|
|
|
|
num_features=1280 if fix_stem_head else round_chs_fn(1280),
|
|
|
|
|
stem_size=32,
|
|
|
|
|
fix_stem=fix_stem_head,
|
|
|
|
|
channel_multiplier=channel_multiplier,
|
|
|
|
|
norm_kwargs=resolve_bn_args(kwargs),
|
|
|
|
|
round_chs_fn=round_chs_fn,
|
|
|
|
|
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
|
|
|
|
|
act_layer=resolve_act_layer(kwargs, 'relu6'),
|
|
|
|
|
**kwargs
|
|
|
|
|
)
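# e.g. mobilenetv2_140 (channel_multiplier=1.4, default head): num_features = round_chs_fn(1280)
# = 1792, since 1280 * 1.4 = 1792 is already a multiple of 8 (a hedged arithmetic illustration).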
|
|
|
|
@@ -629,8 +700,8 @@ def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
|
|
|
|
block_args=decode_arch_def(arch_def),
|
|
|
|
|
stem_size=16,
|
|
|
|
|
num_features=1984, # paper suggests this, but is not 100% clear
|
|
|
|
|
channel_multiplier=channel_multiplier,
|
|
|
|
|
norm_kwargs=resolve_bn_args(kwargs),
|
|
|
|
|
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
|
|
|
|
|
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
|
|
|
|
|
**kwargs
|
|
|
|
|
)
|
|
|
|
|
model = _create_effnet(variant, pretrained, **model_kwargs)
|
|
|
|
@@ -664,8 +735,8 @@ def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
block_args=decode_arch_def(arch_def),
|
|
|
|
|
stem_size=32,
|
|
|
|
|
channel_multiplier=channel_multiplier,
|
|
|
|
|
norm_kwargs=resolve_bn_args(kwargs),
|
|
|
|
|
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
|
|
|
|
|
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
|
|
|
|
|
**kwargs
|
|
|
|
|
)
|
|
|
|
|
model = _create_effnet(variant, pretrained, **model_kwargs)
|
|
|
|
@@ -705,13 +776,14 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
|
|
|
|
|
['ir_r4_k5_s2_e6_c192_se0.25'],
|
|
|
|
|
['ir_r1_k3_s1_e6_c320_se0.25'],
|
|
|
|
|
]
|
|
|
|
|
round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
block_args=decode_arch_def(arch_def, depth_multiplier),
|
|
|
|
|
num_features=round_channels(1280, channel_multiplier, 8, None),
|
|
|
|
|
num_features=round_chs_fn(1280),
|
|
|
|
|
stem_size=32,
|
|
|
|
|
channel_multiplier=channel_multiplier,
|
|
|
|
|
round_chs_fn=round_chs_fn,
|
|
|
|
|
act_layer=resolve_act_layer(kwargs, 'swish'),
|
|
|
|
|
norm_kwargs=resolve_bn_args(kwargs),
|
|
|
|
|
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
|
|
|
|
|
**kwargs,
|
|
|
|
|
)
|
|
|
|
|
model = _create_effnet(variant, pretrained, **model_kwargs)
|
|
|
|
@@ -734,12 +806,13 @@ def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0
|
|
|
|
|
['ir_r4_k5_s1_e8_c144'],
|
|
|
|
|
['ir_r2_k5_s2_e8_c192'],
|
|
|
|
|
]
|
|
|
|
|
round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
block_args=decode_arch_def(arch_def, depth_multiplier),
|
|
|
|
|
num_features=round_channels(1280, channel_multiplier, 8, None),
|
|
|
|
|
num_features=round_chs_fn(1280),
|
|
|
|
|
stem_size=32,
|
|
|
|
|
channel_multiplier=channel_multiplier,
|
|
|
|
|
norm_kwargs=resolve_bn_args(kwargs),
|
|
|
|
|
round_chs_fn=round_chs_fn,
|
|
|
|
|
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
|
|
|
|
|
act_layer=resolve_act_layer(kwargs, 'relu'),
|
|
|
|
|
**kwargs,
|
|
|
|
|
)
|
|
|
|
@@ -764,12 +837,13 @@ def _gen_efficientnet_condconv(
|
|
|
|
|
]
|
|
|
|
|
# NOTE unlike the official impl, this one uses the `cc<x>` option where x is the base number of experts for each stage and
|
|
|
|
|
# experts_multiplier increases that on a per-model basis, as with the depth/channel multipliers
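# e.g. (hedged illustration) a block string ending in '_cc4' requests 4 experts for its CondConv
# layers; with experts_multiplier=2 the decoded blocks would use 8 experts.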
|
|
|
|
|
round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
block_args=decode_arch_def(arch_def, depth_multiplier, experts_multiplier=experts_multiplier),
|
|
|
|
|
num_features=round_channels(1280, channel_multiplier, 8, None),
|
|
|
|
|
num_features=round_chs_fn(1280),
|
|
|
|
|
stem_size=32,
|
|
|
|
|
channel_multiplier=channel_multiplier,
|
|
|
|
|
norm_kwargs=resolve_bn_args(kwargs),
|
|
|
|
|
round_chs_fn=round_chs_fn,
|
|
|
|
|
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
|
|
|
|
|
act_layer=resolve_act_layer(kwargs, 'swish'),
|
|
|
|
|
**kwargs,
|
|
|
|
|
)
|
|
|
|
@@ -809,45 +883,137 @@ def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0
|
|
|
|
|
num_features=1280,
|
|
|
|
|
stem_size=32,
|
|
|
|
|
fix_stem=True,
|
|
|
|
|
channel_multiplier=channel_multiplier,
|
|
|
|
|
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
|
|
|
|
|
act_layer=resolve_act_layer(kwargs, 'relu6'),
|
|
|
|
|
norm_kwargs=resolve_bn_args(kwargs),
|
|
|
|
|
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
|
|
|
|
|
**kwargs,
|
|
|
|
|
)
|
|
|
|
|
model = _create_effnet(variant, pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _gen_efficientnet_v2s(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
|
|
|
|
|
""" Creates an EfficientNet-V2s model
|
|
|
|
|
|
|
|
|
|
NOTE: this is a preliminary definition based on the paper, awaiting official code release for details
|
|
|
|
|
and weights
|
|
|
|
|
def _gen_efficientnetv2_base(
|
|
|
|
|
variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
|
|
|
|
|
""" Creates an EfficientNet-V2 base model
|
|
|
|
|
|
|
|
|
|
Ref impl:
|
|
|
|
|
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
|
|
|
|
|
Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298
|
|
|
|
|
"""
|
|
|
|
|
arch_def = [
|
|
|
|
|
['cn_r1_k3_s1_e1_c16_skip'],
|
|
|
|
|
['er_r2_k3_s2_e4_c32'],
|
|
|
|
|
['er_r2_k3_s2_e4_c48'],
|
|
|
|
|
['ir_r3_k3_s2_e4_c96_se0.25'],
|
|
|
|
|
['ir_r5_k3_s1_e6_c112_se0.25'],
|
|
|
|
|
['ir_r8_k3_s2_e6_c192_se0.25'],
|
|
|
|
|
]
|
|
|
|
|
round_chs_fn = partial(round_channels, multiplier=channel_multiplier, round_limit=0.)
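# round_limit=0. is assumed to disable round_channels' guard against shrinking a value when
# snapping to the divisor, so widths here round strictly to the nearest multiple of 8.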
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
block_args=decode_arch_def(arch_def, depth_multiplier),
|
|
|
|
|
num_features=round_chs_fn(1280),
|
|
|
|
|
stem_size=32,
|
|
|
|
|
round_chs_fn=round_chs_fn,
|
|
|
|
|
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
|
|
|
|
|
act_layer=resolve_act_layer(kwargs, 'silu'),
|
|
|
|
|
**kwargs,
|
|
|
|
|
)
|
|
|
|
|
model = _create_effnet(variant, pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _gen_efficientnetv2_s(
|
|
|
|
|
variant, channel_multiplier=1.0, depth_multiplier=1.0, rw=False, pretrained=False, **kwargs):
|
|
|
|
|
""" Creates an EfficientNet-V2 Small model
|
|
|
|
|
|
|
|
|
|
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
|
|
|
|
|
Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298
|
|
|
|
|
|
|
|
|
|
NOTE: the `rw` flag sets up the 'small' variant to behave like my initial v2 small model,
|
|
|
|
|
before the ref impl was released.
|
|
|
|
|
"""
|
|
|
|
|
arch_def = [
|
|
|
|
|
# FIXME it's not clear if the FusedMBConv layers have SE enabled for the Small variant,
|
|
|
|
|
# Table 4 suggests no. 23.94M params w/o, 23.98M with, which is closer to 24M.
|
|
|
|
|
# ['er_r2_k3_s1_e1_c24_se0.25'],
|
|
|
|
|
# ['er_r4_k3_s2_e4_c48_se0.25'],
|
|
|
|
|
# ['er_r4_k3_s2_e4_c64_se0.25'],
|
|
|
|
|
['er_r2_k3_s1_e1_c24'],
|
|
|
|
|
['cn_r2_k3_s1_e1_c24_skip'],
|
|
|
|
|
['er_r4_k3_s2_e4_c48'],
|
|
|
|
|
['er_r4_k3_s2_e4_c64'],
|
|
|
|
|
['ir_r6_k3_s2_e4_c128_se0.25'],
|
|
|
|
|
['ir_r9_k3_s1_e6_c160_se0.25'],
|
|
|
|
|
['ir_r15_k3_s2_e6_c272_se0.25'],
|
|
|
|
|
['ir_r15_k3_s2_e6_c256_se0.25'],
|
|
|
|
|
]
|
|
|
|
|
num_features = 1280
|
|
|
|
|
if rw:
|
|
|
|
|
# my original variant, based on the paper figure, differs from the official release
|
|
|
|
|
arch_def[0] = ['er_r2_k3_s1_e1_c24']
|
|
|
|
|
arch_def[-1] = ['ir_r15_k3_s2_e6_c272_se0.25']
|
|
|
|
|
num_features = 1792
|
|
|
|
|
|
|
|
|
|
round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
block_args=decode_arch_def(arch_def, depth_multiplier),
|
|
|
|
|
num_features=round_chs_fn(num_features),
|
|
|
|
|
stem_size=24,
|
|
|
|
|
round_chs_fn=round_chs_fn,
|
|
|
|
|
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
|
|
|
|
|
act_layer=resolve_act_layer(kwargs, 'silu'),
|
|
|
|
|
**kwargs,
|
|
|
|
|
)
|
|
|
|
|
model = _create_effnet(variant, pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _gen_efficientnetv2_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
|
|
|
|
|
""" Creates an EfficientNet-V2 Medium model
|
|
|
|
|
|
|
|
|
|
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
|
|
|
|
|
Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
arch_def = [
|
|
|
|
|
['cn_r3_k3_s1_e1_c24_skip'],
|
|
|
|
|
['er_r5_k3_s2_e4_c48'],
|
|
|
|
|
['er_r5_k3_s2_e4_c80'],
|
|
|
|
|
['ir_r7_k3_s2_e4_c160_se0.25'],
|
|
|
|
|
['ir_r14_k3_s1_e6_c176_se0.25'],
|
|
|
|
|
['ir_r18_k3_s2_e6_c304_se0.25'],
|
|
|
|
|
['ir_r5_k3_s1_e6_c512_se0.25'],
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
block_args=decode_arch_def(arch_def, depth_multiplier),
|
|
|
|
|
num_features=round_channels(1792, channel_multiplier, 8, None),
|
|
|
|
|
num_features=1280,
|
|
|
|
|
stem_size=24,
|
|
|
|
|
channel_multiplier=channel_multiplier,
|
|
|
|
|
norm_kwargs=resolve_bn_args(kwargs),
|
|
|
|
|
act_layer=resolve_act_layer(kwargs, 'silu'), # FIXME this is an assumption, paper does not mention
|
|
|
|
|
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
|
|
|
|
|
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
|
|
|
|
|
act_layer=resolve_act_layer(kwargs, 'silu'),
|
|
|
|
|
**kwargs,
|
|
|
|
|
)
|
|
|
|
|
model = _create_effnet(variant, pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _gen_efficientnetv2_l(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
|
|
|
|
|
""" Creates an EfficientNet-V2 Large model
|
|
|
|
|
|
|
|
|
|
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
|
|
|
|
|
Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
arch_def = [
|
|
|
|
|
['cn_r4_k3_s1_e1_c32_skip'],
|
|
|
|
|
['er_r7_k3_s2_e4_c64'],
|
|
|
|
|
['er_r7_k3_s2_e4_c96'],
|
|
|
|
|
['ir_r10_k3_s2_e4_c192_se0.25'],
|
|
|
|
|
['ir_r19_k3_s1_e6_c224_se0.25'],
|
|
|
|
|
['ir_r25_k3_s2_e6_c384_se0.25'],
|
|
|
|
|
['ir_r7_k3_s1_e6_c640_se0.25'],
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
block_args=decode_arch_def(arch_def, depth_multiplier),
|
|
|
|
|
num_features=1280,
|
|
|
|
|
stem_size=32,
|
|
|
|
|
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
|
|
|
|
|
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
|
|
|
|
|
act_layer=resolve_act_layer(kwargs, 'silu'),
|
|
|
|
|
**kwargs,
|
|
|
|
|
)
|
|
|
|
|
model = _create_effnet(variant, pretrained, **model_kwargs)
|
|
|
|
@@ -879,8 +1045,8 @@ def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
|
|
|
|
block_args=decode_arch_def(arch_def),
|
|
|
|
|
num_features=1536,
|
|
|
|
|
stem_size=16,
|
|
|
|
|
channel_multiplier=channel_multiplier,
|
|
|
|
|
norm_kwargs=resolve_bn_args(kwargs),
|
|
|
|
|
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
|
|
|
|
|
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
|
|
|
|
|
**kwargs
|
|
|
|
|
)
|
|
|
|
|
model = _create_effnet(variant, pretrained, **model_kwargs)
|
|
|
|
@@ -912,8 +1078,8 @@ def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrai
|
|
|
|
|
block_args=decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'),
|
|
|
|
|
num_features=1536,
|
|
|
|
|
stem_size=24,
|
|
|
|
|
channel_multiplier=channel_multiplier,
|
|
|
|
|
norm_kwargs=resolve_bn_args(kwargs),
|
|
|
|
|
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
|
|
|
|
|
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
|
|
|
|
|
**kwargs
|
|
|
|
|
)
|
|
|
|
|
model = _create_effnet(variant, pretrained, **model_kwargs)
|
|
|
|
@@ -1290,13 +1456,35 @@ def efficientnet_b3_pruned(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def efficientnet_v2s(pretrained=False, **kwargs):
|
|
|
|
|
def efficientnetv2_rw_s(pretrained=False, **kwargs):
|
|
|
|
|
""" EfficientNet-V2 Small.
|
|
|
|
|
NOTE: This is my initial version (pre official code release) w/ some differences.
|
|
|
|
|
See efficientnetv2_s and tf_efficientnetv2_s for versions that match the official release, w/ PyTorch vs TF padding respectively.
|
|
|
|
|
"""
|
|
|
|
|
model = _gen_efficientnetv2_s('efficientnetv2_rw_s', rw=True, pretrained=pretrained, **kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def efficientnetv2_s(pretrained=False, **kwargs):
|
|
|
|
|
""" EfficientNet-V2 Small. """
|
|
|
|
|
model = _gen_efficientnet_v2s(
|
|
|
|
|
'efficientnet_v2s', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
|
|
|
|
|
model = _gen_efficientnetv2_s('efficientnetv2_s', pretrained=pretrained, **kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def efficientnetv2_m(pretrained=False, **kwargs):
|
|
|
|
|
""" EfficientNet-V2 Medium. """
|
|
|
|
|
model = _gen_efficientnetv2_m('efficientnetv2_m', pretrained=pretrained, **kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def efficientnetv2_l(pretrained=False, **kwargs):
|
|
|
|
|
""" EfficientNet-V2 Large. """
|
|
|
|
|
model = _gen_efficientnetv2_l('efficientnetv2_l', pretrained=pretrained, **kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def tf_efficientnet_b0(pretrained=False, **kwargs):
|
|
|
|
@@ -1708,6 +1896,127 @@ def tf_efficientnet_lite4(pretrained=False, **kwargs):
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def tf_efficientnetv2_s(pretrained=False, **kwargs):
|
|
|
|
|
""" EfficientNet-V2 Small. Tensorflow compatible variant """
|
|
|
|
|
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
|
|
|
|
kwargs['pad_type'] = 'same'
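# (assumed rationale) BN_EPS_TF_DEFAULT matches TensorFlow's BatchNorm eps (1e-3, vs PyTorch's
# 1e-5 default) and 'same' padding emulates TF SAME convolution padding so ported weights line up.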
|
|
|
|
|
model = _gen_efficientnetv2_s('tf_efficientnetv2_s', pretrained=pretrained, **kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def tf_efficientnetv2_m(pretrained=False, **kwargs):
|
|
|
|
|
""" EfficientNet-V2 Medium. Tensorflow compatible variant """
|
|
|
|
|
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
|
|
|
|
kwargs['pad_type'] = 'same'
|
|
|
|
|
model = _gen_efficientnetv2_m('tf_efficientnetv2_m', pretrained=pretrained, **kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def tf_efficientnetv2_l(pretrained=False, **kwargs):
|
|
|
|
|
""" EfficientNet-V2 Large. Tensorflow compatible variant """
|
|
|
|
|
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
|
|
|
|
kwargs['pad_type'] = 'same'
|
|
|
|
|
model = _gen_efficientnetv2_l('tf_efficientnetv2_l', pretrained=pretrained, **kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def tf_efficientnetv2_s_21kft1k(pretrained=False, **kwargs):
|
|
|
|
|
""" EfficientNet-V2 Small. Tensorflow compatible variant """
|
|
|
|
|
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
|
|
|
|
kwargs['pad_type'] = 'same'
|
|
|
|
|
model = _gen_efficientnetv2_s('tf_efficientnetv2_s_21kft1k', pretrained=pretrained, **kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def tf_efficientnetv2_m_21kft1k(pretrained=False, **kwargs):
|
|
|
|
|
""" EfficientNet-V2 Medium. Tensorflow compatible variant """
|
|
|
|
|
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
|
|
|
|
kwargs['pad_type'] = 'same'
|
|
|
|
|
model = _gen_efficientnetv2_m('tf_efficientnetv2_m_21kft1k', pretrained=pretrained, **kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def tf_efficientnetv2_l_21kft1k(pretrained=False, **kwargs):
|
|
|
|
|
""" EfficientNet-V2 Large. Tensorflow compatible variant """
|
|
|
|
|
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
|
|
|
|
kwargs['pad_type'] = 'same'
|
|
|
|
|
model = _gen_efficientnetv2_l('tf_efficientnetv2_l_21kft1k', pretrained=pretrained, **kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def tf_efficientnetv2_s_21k(pretrained=False, **kwargs):
|
|
|
|
|
""" EfficientNet-V2 Small w/ ImageNet-21k pretrained weights. Tensorflow compatible variant """
|
|
|
|
|
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
|
|
|
|
kwargs['pad_type'] = 'same'
|
|
|
|
|
model = _gen_efficientnetv2_s('tf_efficientnetv2_s_21k', pretrained=pretrained, **kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def tf_efficientnetv2_m_21k(pretrained=False, **kwargs):
|
|
|
|
|
""" EfficientNet-V2 Medium w/ ImageNet-21k pretrained weights. Tensorflow compatible variant """
|
|
|
|
|
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
|
|
|
|
kwargs['pad_type'] = 'same'
|
|
|
|
|
model = _gen_efficientnetv2_m('tf_efficientnetv2_m_21k', pretrained=pretrained, **kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def tf_efficientnetv2_l_21k(pretrained=False, **kwargs):
|
|
|
|
|
""" EfficientNet-V2 Large w/ ImageNet-21k pretrained weights. Tensorflow compatible variant """
|
|
|
|
|
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
|
|
|
|
kwargs['pad_type'] = 'same'
|
|
|
|
|
model = _gen_efficientnetv2_l('tf_efficientnetv2_l_21k', pretrained=pretrained, **kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def tf_efficientnetv2_b0(pretrained=False, **kwargs):
|
|
|
|
|
""" EfficientNet-V2-B0. Tensorflow compatible variant """
|
|
|
|
|
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
|
|
|
|
kwargs['pad_type'] = 'same'
|
|
|
|
|
model = _gen_efficientnetv2_base('tf_efficientnetv2_b0', pretrained=pretrained, **kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def tf_efficientnetv2_b1(pretrained=False, **kwargs):
|
|
|
|
|
""" EfficientNet-V2-B1. Tensorflow compatible variant """
|
|
|
|
|
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
|
|
|
|
kwargs['pad_type'] = 'same'
|
|
|
|
|
model = _gen_efficientnetv2_base(
|
|
|
|
|
'tf_efficientnetv2_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def tf_efficientnetv2_b2(pretrained=False, **kwargs):
|
|
|
|
|
""" EfficientNet-V2-B2. Tensorflow compatible variant """
|
|
|
|
|
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
|
|
|
|
kwargs['pad_type'] = 'same'
|
|
|
|
|
model = _gen_efficientnetv2_base(
|
|
|
|
|
'tf_efficientnetv2_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def tf_efficientnetv2_b3(pretrained=False, **kwargs):
|
|
|
|
|
""" EfficientNet-V2-B3. Tensorflow compatible variant """
|
|
|
|
|
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
|
|
|
|
kwargs['pad_type'] = 'same'
|
|
|
|
|
model = _gen_efficientnetv2_base(
|
|
|
|
|
'tf_efficientnetv2_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def mixnet_s(pretrained=False, **kwargs):
|
|
|
|
|
"""Creates a MixNet Small model.
|
|
|
|
|