Add new auto-augmentation TensorFlow EfficientNet weights, including B6 and B7 models. Validation scores are still pending, but early results look good.

pull/23/head
Ross Wightman 5 years ago
parent 857f33015a
commit 77e2e0c4e3

@ -84,24 +84,34 @@ default_cfgs = {
url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
'efficientnet_b5': _cfg( 'efficientnet_b5': _cfg(
url='', input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), url='', input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
'efficientnet_b6': _cfg(
url='', input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942),
'efficientnet_b7': _cfg(
url='', input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
'tf_efficientnet_b0': _cfg( 'tf_efficientnet_b0': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0-0af12548.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
input_size=(3, 224, 224)), input_size=(3, 224, 224)),
'tf_efficientnet_b1': _cfg( 'tf_efficientnet_b1': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1-5c1377c4.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth',
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
'tf_efficientnet_b2': _cfg( 'tf_efficientnet_b2': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2-e393ef04.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97.pth',
input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
'tf_efficientnet_b3': _cfg( 'tf_efficientnet_b3': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3-e3bd6955.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth',
input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
'tf_efficientnet_b4': _cfg( 'tf_efficientnet_b4': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4-74ee3bed.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth',
input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
'tf_efficientnet_b5': _cfg( 'tf_efficientnet_b5': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5-c6949ce9.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_aa-99018a74.pth',
input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
'tf_efficientnet_b6': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth',
input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942),
'tf_efficientnet_b7': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_aa-076e3472.pth',
input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
'mixnet_s': _cfg(url=''), 'mixnet_s': _cfg(url=''),
'mixnet_m': _cfg( 'mixnet_m': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth'), url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth'),
@ -763,8 +773,6 @@ def _gen_mnasnet_a1(channel_multiplier, num_classes=1000, **kwargs):
num_classes=num_classes, num_classes=num_classes,
stem_size=32, stem_size=32,
channel_multiplier=channel_multiplier, channel_multiplier=channel_multiplier,
channel_divisor=8,
channel_min=None,
bn_args=_resolve_bn_args(kwargs), bn_args=_resolve_bn_args(kwargs),
**kwargs **kwargs
) )
@ -801,8 +809,6 @@ def _gen_mnasnet_b1(channel_multiplier, num_classes=1000, **kwargs):
num_classes=num_classes, num_classes=num_classes,
stem_size=32, stem_size=32,
channel_multiplier=channel_multiplier, channel_multiplier=channel_multiplier,
channel_divisor=8,
channel_min=None,
bn_args=_resolve_bn_args(kwargs), bn_args=_resolve_bn_args(kwargs),
**kwargs **kwargs
) )
@ -832,8 +838,6 @@ def _gen_mnasnet_small(channel_multiplier, num_classes=1000, **kwargs):
num_classes=num_classes, num_classes=num_classes,
stem_size=8, stem_size=8,
channel_multiplier=channel_multiplier, channel_multiplier=channel_multiplier,
channel_divisor=8,
channel_min=None,
bn_args=_resolve_bn_args(kwargs), bn_args=_resolve_bn_args(kwargs),
**kwargs **kwargs
) )
@ -858,8 +862,6 @@ def _gen_mobilenet_v1(channel_multiplier, num_classes=1000, **kwargs):
stem_size=32, stem_size=32,
num_features=1024, num_features=1024,
channel_multiplier=channel_multiplier, channel_multiplier=channel_multiplier,
channel_divisor=8,
channel_min=None,
bn_args=_resolve_bn_args(kwargs), bn_args=_resolve_bn_args(kwargs),
act_fn=F.relu6, act_fn=F.relu6,
head_conv='none', head_conv='none',
@ -887,8 +889,6 @@ def _gen_mobilenet_v2(channel_multiplier, num_classes=1000, **kwargs):
num_classes=num_classes, num_classes=num_classes,
stem_size=32, stem_size=32,
channel_multiplier=channel_multiplier, channel_multiplier=channel_multiplier,
channel_divisor=8,
channel_min=None,
bn_args=_resolve_bn_args(kwargs), bn_args=_resolve_bn_args(kwargs),
act_fn=F.relu6, act_fn=F.relu6,
**kwargs **kwargs
@ -926,8 +926,6 @@ def _gen_mobilenet_v3(channel_multiplier, num_classes=1000, **kwargs):
num_classes=num_classes, num_classes=num_classes,
stem_size=16, stem_size=16,
channel_multiplier=channel_multiplier, channel_multiplier=channel_multiplier,
channel_divisor=8,
channel_min=None,
bn_args=_resolve_bn_args(kwargs), bn_args=_resolve_bn_args(kwargs),
act_fn=hard_swish, act_fn=hard_swish,
se_gate_fn=hard_sigmoid, se_gate_fn=hard_sigmoid,
@ -961,8 +959,6 @@ def _gen_chamnet_v1(channel_multiplier, num_classes=1000, **kwargs):
stem_size=32, stem_size=32,
num_features=1280, # no idea what this is? try mobile/mnasnet default? num_features=1280, # no idea what this is? try mobile/mnasnet default?
channel_multiplier=channel_multiplier, channel_multiplier=channel_multiplier,
channel_divisor=8,
channel_min=None,
bn_args=_resolve_bn_args(kwargs), bn_args=_resolve_bn_args(kwargs),
**kwargs **kwargs
) )
@ -992,8 +988,6 @@ def _gen_chamnet_v2(channel_multiplier, num_classes=1000, **kwargs):
stem_size=32, stem_size=32,
num_features=1280, # no idea what this is? try mobile/mnasnet default? num_features=1280, # no idea what this is? try mobile/mnasnet default?
channel_multiplier=channel_multiplier, channel_multiplier=channel_multiplier,
channel_divisor=8,
channel_min=None,
bn_args=_resolve_bn_args(kwargs), bn_args=_resolve_bn_args(kwargs),
**kwargs **kwargs
) )
@ -1024,8 +1018,6 @@ def _gen_fbnetc(channel_multiplier, num_classes=1000, **kwargs):
stem_size=16, stem_size=16,
num_features=1984, # paper suggests this, but is not 100% clear num_features=1984, # paper suggests this, but is not 100% clear
channel_multiplier=channel_multiplier, channel_multiplier=channel_multiplier,
channel_divisor=8,
channel_min=None,
bn_args=_resolve_bn_args(kwargs), bn_args=_resolve_bn_args(kwargs),
**kwargs **kwargs
) )
@ -1061,8 +1053,6 @@ def _gen_spnasnet(channel_multiplier, num_classes=1000, **kwargs):
num_classes=num_classes, num_classes=num_classes,
stem_size=32, stem_size=32,
channel_multiplier=channel_multiplier, channel_multiplier=channel_multiplier,
channel_divisor=8,
channel_min=None,
bn_args=_resolve_bn_args(kwargs), bn_args=_resolve_bn_args(kwargs),
**kwargs **kwargs
) )
@ -1107,8 +1097,6 @@ def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=
num_classes=num_classes, num_classes=num_classes,
stem_size=32, stem_size=32,
channel_multiplier=channel_multiplier, channel_multiplier=channel_multiplier,
channel_divisor=8,
channel_min=None,
num_features=num_features, num_features=num_features,
bn_args=_resolve_bn_args(kwargs), bn_args=_resolve_bn_args(kwargs),
act_fn=swish, act_fn=swish,
@ -1144,8 +1132,6 @@ def _gen_mixnet_s(channel_multiplier=1.0, num_classes=1000, **kwargs):
stem_size=16, stem_size=16,
num_features=1536, num_features=1536,
channel_multiplier=channel_multiplier, channel_multiplier=channel_multiplier,
channel_divisor=8,
channel_min=None,
bn_args=_resolve_bn_args(kwargs), bn_args=_resolve_bn_args(kwargs),
act_fn=F.relu, act_fn=F.relu,
**kwargs **kwargs
@ -1180,8 +1166,6 @@ def _gen_mixnet_m(channel_multiplier=1.0, num_classes=1000, **kwargs):
stem_size=24, stem_size=24,
num_features=1536, num_features=1536,
channel_multiplier=channel_multiplier, channel_multiplier=channel_multiplier,
channel_divisor=8,
channel_min=None,
bn_args=_resolve_bn_args(kwargs), bn_args=_resolve_bn_args(kwargs),
act_fn=F.relu, act_fn=F.relu,
**kwargs **kwargs
@ -1495,6 +1479,37 @@ def efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
return model return model
@register_model
def efficientnet_b6(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Build an EfficientNet-B6 model.

    The network is generated with channel_multiplier=1.8 and
    depth_multiplier=2.6, and the 'efficientnet_b6' entry of `default_cfgs`
    is attached to it as `default_cfg`.

    Args:
        pretrained: when truthy, `load_pretrained` is invoked on the new
            model with its default cfg, `num_classes` and `in_chans`.
        num_classes: forwarded to `_gen_efficientnet` (and `load_pretrained`).
        in_chans: forwarded to `_gen_efficientnet` (and `load_pretrained`).
        **kwargs: extra keyword arguments passed through to
            `_gen_efficientnet` unchanged.

    Returns:
        The constructed model.
    """
    # NOTE for train, drop_rate should be 0.5
    # NOTE when training, drop_connect_rate is intended to be 0.2
    # (kwargs['drop_connect_rate'] = 0.2); TODO add as cmd arg
    cfg = default_cfgs['efficientnet_b6']
    net = _gen_efficientnet(
        channel_multiplier=1.8,
        depth_multiplier=2.6,
        num_classes=num_classes,
        in_chans=in_chans,
        **kwargs)
    net.default_cfg = cfg
    if not pretrained:
        return net
    load_pretrained(net, cfg, num_classes, in_chans)
    return net
@register_model
def efficientnet_b7(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Build an EfficientNet-B7 model.

    The network is generated with channel_multiplier=2.0 and
    depth_multiplier=3.1, and the 'efficientnet_b7' entry of `default_cfgs`
    is attached to it as `default_cfg`.

    Args:
        pretrained: when truthy, `load_pretrained` is invoked on the new
            model with its default cfg, `num_classes` and `in_chans`.
        num_classes: forwarded to `_gen_efficientnet` (and `load_pretrained`).
        in_chans: forwarded to `_gen_efficientnet` (and `load_pretrained`).
        **kwargs: extra keyword arguments passed through to
            `_gen_efficientnet` unchanged.

    Returns:
        The constructed model.
    """
    # NOTE for train, drop_rate should be 0.5
    # NOTE when training, drop_connect_rate is intended to be 0.2
    # (kwargs['drop_connect_rate'] = 0.2); TODO add as cmd arg
    cfg = default_cfgs['efficientnet_b7']
    net = _gen_efficientnet(
        channel_multiplier=2.0,
        depth_multiplier=3.1,
        num_classes=num_classes,
        in_chans=in_chans,
        **kwargs)
    net.default_cfg = cfg
    if not pretrained:
        return net
    load_pretrained(net, cfg, num_classes, in_chans)
    return net
@register_model @register_model
def tf_efficientnet_b0(pretrained=False, num_classes=1000, in_chans=3, **kwargs): def tf_efficientnet_b0(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
""" EfficientNet-B0. Tensorflow compatible variant """ """ EfficientNet-B0. Tensorflow compatible variant """
@ -1585,6 +1600,38 @@ def tf_efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
return model return model
@register_model
def tf_efficientnet_b6(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Build an EfficientNet-B6 model, Tensorflow compatible variant.

    Same multipliers as the plain B6 builder (channel 1.8, depth 2.6), but
    `bn_eps` is forced to `_BN_EPS_TF_DEFAULT` and `pad_type` to 'same',
    overriding any caller-supplied values for those two keys. The
    'tf_efficientnet_b6' entry of `default_cfgs` is attached as `default_cfg`.

    Args:
        pretrained: when truthy, `load_pretrained` is invoked on the new
            model with its default cfg, `num_classes` and `in_chans`.
        num_classes: forwarded to `_gen_efficientnet` (and `load_pretrained`).
        in_chans: forwarded to `_gen_efficientnet` (and `load_pretrained`).
        **kwargs: extra keyword arguments passed through to
            `_gen_efficientnet` (after the two overrides above).

    Returns:
        The constructed model.
    """
    # NOTE for train, drop_rate should be 0.5
    cfg = default_cfgs['tf_efficientnet_b6']
    # Later keywords win, so the TF-specific settings replace caller values.
    tf_kwargs = dict(kwargs, bn_eps=_BN_EPS_TF_DEFAULT, pad_type='same')
    net = _gen_efficientnet(
        channel_multiplier=1.8,
        depth_multiplier=2.6,
        num_classes=num_classes,
        in_chans=in_chans,
        **tf_kwargs)
    net.default_cfg = cfg
    if not pretrained:
        return net
    load_pretrained(net, cfg, num_classes, in_chans)
    return net
@register_model
def tf_efficientnet_b7(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Build an EfficientNet-B7 model, Tensorflow compatible variant.

    Same multipliers as the plain B7 builder (channel 2.0, depth 3.1), but
    `bn_eps` is forced to `_BN_EPS_TF_DEFAULT` and `pad_type` to 'same',
    overriding any caller-supplied values for those two keys. The
    'tf_efficientnet_b7' entry of `default_cfgs` is attached as `default_cfg`.

    Args:
        pretrained: when truthy, `load_pretrained` is invoked on the new
            model with its default cfg, `num_classes` and `in_chans`.
        num_classes: forwarded to `_gen_efficientnet` (and `load_pretrained`).
        in_chans: forwarded to `_gen_efficientnet` (and `load_pretrained`).
        **kwargs: extra keyword arguments passed through to
            `_gen_efficientnet` (after the two overrides above).

    Returns:
        The constructed model.
    """
    # NOTE for train, drop_rate should be 0.5
    cfg = default_cfgs['tf_efficientnet_b7']
    # Later keywords win, so the TF-specific settings replace caller values.
    tf_kwargs = dict(kwargs, bn_eps=_BN_EPS_TF_DEFAULT, pad_type='same')
    net = _gen_efficientnet(
        channel_multiplier=2.0,
        depth_multiplier=3.1,
        num_classes=num_classes,
        in_chans=in_chans,
        **tf_kwargs)
    net.default_cfg = cfg
    if not pretrained:
        return net
    load_pretrained(net, cfg, num_classes, in_chans)
    return net
@register_model @register_model
def mixnet_s(pretrained=False, num_classes=1000, in_chans=3, **kwargs): def mixnet_s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Creates a MixNet Small model. """Creates a MixNet Small model.

Loading…
Cancel
Save