Experimenting with a custom MixNet-XL and MixNet-XXL definition

pull/30/head
Ross Wightman 5 years ago
parent 9816ca3ab4
commit 51a2375b0c

@ -138,6 +138,8 @@ default_cfgs = {
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth'), url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth'),
'mixnet_l': _cfg( 'mixnet_l': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_l-5a9a2ed8.pth'), url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_l-5a9a2ed8.pth'),
'mixnet_xl': _cfg(),
'mixnet_xxl': _cfg(),
'tf_mixnet_s': _cfg( 'tf_mixnet_s': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth'), url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth'),
'tf_mixnet_m': _cfg( 'tf_mixnet_m': _cfg(
@ -312,21 +314,59 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
else: else:
assert False, 'Unknown block type (%s)' % block_type assert False, 'Unknown block type (%s)' % block_type
# return a list of block args expanded by num_repeat and return block_args, num_repeat
# scaled by depth_multiplier
num_repeat = int(math.ceil(num_repeat * depth_multiplier))
return [deepcopy(block_args) for _ in range(num_repeat)]
def _decode_arch_def(arch_def, depth_multiplier=1.0): def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
""" Per-stage depth scaling
Scales the block repeats in each stage. This depth scaling impl maintains
compatibility with the EfficientNet scaling method, while allowing sensible
scaling for other models that may have multiple block arg definitions in each stage.
"""
# We scale the total repeat count for each stage, there may be multiple
# block arg defs per stage so we need to sum.
num_repeat = sum(repeats)
if depth_trunc == 'round':
# Truncating to int by rounding allows stages with few repeats to remain
# proportionally smaller for longer. This is a good choice when stage definitions
# include single repeat stages that we'd prefer to keep that way as long as possible
num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
else:
# The default for EfficientNet truncates repeats to int via 'ceil'.
# Any multiplier > 1.0 will result in an increased depth for every stage.
num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))
# Proportionally distribute repeat count scaling to each block definition in the stage.
# Allocation is done in reverse as it results in the first block being less likely to be scaled.
# The first block makes less sense to repeat in most of the arch definitions.
repeats_scaled = []
for r in repeats[::-1]:
rs = max(1, round((r / num_repeat * num_repeat_scaled)))
repeats_scaled.append(rs)
num_repeat -= r
num_repeat_scaled -= rs
repeats_scaled = repeats_scaled[::-1]
# Apply the calculated scaling to each block arg in the stage
sa_scaled = []
for ba, rep in zip(stack_args, repeats_scaled):
sa_scaled.extend([deepcopy(ba) for _ in range(rep)])
return sa_scaled
def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil'):
arch_args = [] arch_args = []
for stack_idx, block_strings in enumerate(arch_def): for stack_idx, block_strings in enumerate(arch_def):
assert isinstance(block_strings, list) assert isinstance(block_strings, list)
stack_args = [] stack_args = []
repeats = []
for block_str in block_strings: for block_str in block_strings:
assert isinstance(block_str, str) assert isinstance(block_str, str)
stack_args.extend(_decode_block_str(block_str, depth_multiplier)) ba, rep = _decode_block_str(block_str)
arch_args.append(stack_args) stack_args.append(ba)
repeats.append(rep)
arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc))
return arch_args return arch_args
@ -1261,7 +1301,7 @@ def _gen_mixnet_s(channel_multiplier=1.0, num_classes=1000, **kwargs):
return model return model
def _gen_mixnet_m(channel_multiplier=1.0, num_classes=1000, **kwargs): def _gen_mixnet_m(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000, **kwargs):
"""Creates a MixNet Medium-Large model. """Creates a MixNet Medium-Large model.
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet
@ -1283,7 +1323,7 @@ def _gen_mixnet_m(channel_multiplier=1.0, num_classes=1000, **kwargs):
# 7x7 # 7x7
] ]
model = GenEfficientNet( model = GenEfficientNet(
_decode_arch_def(arch_def), _decode_arch_def(arch_def, depth_multiplier=depth_multiplier, depth_trunc='round'),
num_classes=num_classes, num_classes=num_classes,
stem_size=24, stem_size=24,
num_features=1536, num_features=1536,
@ -1876,6 +1916,33 @@ def mixnet_l(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
return model return model
@register_model
def mixnet_xl(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Creates a MixNet Extra-Large model.
"""
default_cfg = default_cfgs['mixnet_xl']
#kwargs['drop_connect_rate'] = 0.2
model = _gen_mixnet_m(
channel_multiplier=1.6, depth_multiplier=1.2, num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def mixnet_xxl(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Creates a MixNet Double Extra Large model.
"""
default_cfg = default_cfgs['mixnet_xxl']
model = _gen_mixnet_m(
channel_multiplier=2.4, depth_multiplier=1.3, num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model @register_model
def tf_mixnet_s(pretrained=False, num_classes=1000, in_chans=3, **kwargs): def tf_mixnet_s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Creates a MixNet Small model. Tensorflow compatible variant """Creates a MixNet Small model. Tensorflow compatible variant

Loading…
Cancel
Save