Prep for efficientnetv2_rw_m model weights that started training before the official release.

pull/660/head
Ross Wightman 4 years ago
parent 22f7c6760f
commit c2ba229d99

@@ -162,6 +162,9 @@ default_cfgs = {
'efficientnetv2_rw_s': _cfg( 'efficientnetv2_rw_s': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_v2s_ra2_288-a6477665.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_v2s_ra2_288-a6477665.pth',
input_size=(3, 288, 288), test_input_size=(3, 384, 384), pool_size=(9, 9), crop_pct=1.0), input_size=(3, 288, 288), test_input_size=(3, 384, 384), pool_size=(9, 9), crop_pct=1.0),
'efficientnetv2_rw_m': _cfg(
url='',
input_size=(3, 320, 320), test_input_size=(3, 416, 416), pool_size=(10, 10), crop_pct=1.0),
'efficientnetv2_s': _cfg( 'efficientnetv2_s': _cfg(
url='', url='',
@@ -173,7 +176,6 @@ default_cfgs = {
url='', url='',
input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0),
'tf_efficientnet_b0': _cfg( 'tf_efficientnet_b0': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
input_size=(3, 224, 224)), input_size=(3, 224, 224)),
@@ -1461,7 +1463,7 @@ def efficientnet_b3_pruned(pretrained=False, **kwargs):
@register_model @register_model
def efficientnetv2_rw_s(pretrained=False, **kwargs): def efficientnetv2_rw_s(pretrained=False, **kwargs):
""" EfficientNet-V2 Small. """ EfficientNet-V2 Small RW variant.
NOTE: This is my initial (pre official code release) w/ some differences. NOTE: This is my initial (pre official code release) w/ some differences.
See efficientnetv2_s and tf_efficientnetv2_s for versions that match the official w/ PyTorch vs TF padding See efficientnetv2_s and tf_efficientnetv2_s for versions that match the official w/ PyTorch vs TF padding
""" """
@@ -1469,6 +1471,16 @@ def efficientnetv2_rw_s(pretrained=False, **kwargs):
return model return model
@register_model
def efficientnetv2_rw_m(pretrained=False, **kwargs):
    """ EfficientNet-V2 Medium RW variant.

    NOTE(review): pre-official-release variant (see efficientnetv2_rw_s); uses
    per-stage depth scaling — the last two stages get a larger multiplier (1.6)
    than the first four (1.2).
    """
    return _gen_efficientnetv2_s(
        'efficientnetv2_rw_m',
        channel_multiplier=1.2,
        depth_multiplier=(1.2,) * 4 + (1.6,) * 2,
        rw=True,
        pretrained=pretrained,
        **kwargs)
@register_model @register_model
def efficientnetv2_s(pretrained=False, **kwargs): def efficientnetv2_s(pretrained=False, **kwargs):
""" EfficientNet-V2 Small. """ """ EfficientNet-V2 Small. """

@@ -237,7 +237,11 @@ def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='c
def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False): def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False):
arch_args = [] arch_args = []
for stack_idx, block_strings in enumerate(arch_def): if isinstance(depth_multiplier, tuple):
assert len(depth_multiplier) == len(arch_def)
else:
depth_multiplier = (depth_multiplier,) * len(arch_def)
for stack_idx, (block_strings, multiplier) in enumerate(zip(arch_def, depth_multiplier)):
assert isinstance(block_strings, list) assert isinstance(block_strings, list)
stack_args = [] stack_args = []
repeats = [] repeats = []
@@ -251,7 +255,7 @@ def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_
if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1): if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1):
arch_args.append(_scale_stage_depth(stack_args, repeats, 1.0, depth_trunc)) arch_args.append(_scale_stage_depth(stack_args, repeats, 1.0, depth_trunc))
else: else:
arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc)) arch_args.append(_scale_stage_depth(stack_args, repeats, multiplier, depth_trunc))
return arch_args return arch_args

Loading…
Cancel
Save