Merge branch 'yoniaflalo-adding_ECA_resnet'

pull/136/head
Ross Wightman 4 years ago
commit 7a9942a75e

@ -1,8 +1,11 @@
import torch
import torch.nn as nn
from copy import deepcopy
import torch.utils.model_zoo as model_zoo
import os
import logging
from collections import OrderedDict
from timm.models.layers.conv2d_same import Conv2dSame
def load_state_dict(checkpoint_path, use_ema=False):
@ -98,7 +101,96 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
model.load_state_dict(state_dict, strict=strict)
def extract_layer(model, layer):
layer = layer.split('.')
module = model
if hasattr(model, 'module') and layer[0] != 'module':
module = model.module
if not hasattr(model, 'module') and layer[0] == 'module':
layer = layer[1:]
for l in layer:
if hasattr(module, l):
if not l.isdigit():
module = getattr(module, l)
else:
module = module[int(l)]
else:
return module
return module
def set_layer(model, layer, val):
layer = layer.split('.')
module = model
if hasattr(model, 'module') and layer[0] != 'module':
module = model.module
lst_index = 0
module2 = module
for l in layer:
if hasattr(module2, l):
if not l.isdigit():
module2 = getattr(module2, l)
else:
module2 = module2[int(l)]
lst_index += 1
lst_index -= 1
for l in layer[:lst_index]:
if not l.isdigit():
module = getattr(module, l)
else:
module = module[int(l)]
l = layer[lst_index]
setattr(module, l, val)
def adapt_model_from_string(parent_module, model_string):
separator = '***'
state_dict = {}
lst_shape = model_string.split(separator)
for k in lst_shape:
k = k.split(':')
key = k[0]
shape = k[1][1:-1].split(',')
if shape[0] != '':
state_dict[key] = [int(i) for i in shape]
new_module = deepcopy(parent_module)
for n, m in parent_module.named_modules():
old_module = extract_layer(parent_module, n)
if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame):
if isinstance(old_module, Conv2dSame):
conv = Conv2dSame
else:
conv = nn.Conv2d
s = state_dict[n + '.weight']
in_channels = s[1]
out_channels = s[0]
g = 1
if old_module.groups > 1:
in_channels = out_channels
g = in_channels
new_conv = conv(
in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size,
bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
groups=g, stride=old_module.stride)
set_layer(new_module, n, new_conv)
if isinstance(old_module, nn.BatchNorm2d):
new_bn = nn.BatchNorm2d(
num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
affine=old_module.affine, track_running_stats=True)
set_layer(new_module, n, new_bn)
if isinstance(old_module, nn.Linear):
new_fc = nn.Linear(
in_features=state_dict[n + '.weight'][1], out_features=old_module.out_features,
bias=old_module.bias is not None)
set_layer(new_module, n, new_fc)
new_module.eval()
parent_module.eval()
return new_module
def adapt_model_from_file(parent_module, model_variant):
adapt_file = os.path.join(os.path.dirname(__file__), 'pruned', model_variant + '.txt')
with open(adapt_file, 'r') as f:
return adapt_model_from_string(parent_module, f.read().strip())

File diff suppressed because one or more lines are too long

@ -0,0 +1 @@
conv1.0.weight:[32, 3, 3, 3]***conv1.1.weight:[32]***conv1.3.weight:[32, 32, 3, 3]***conv1.4.weight:[32]***conv1.6.weight:[64, 32, 3, 3]***bn1.weight:[64]***layer1.0.conv1.weight:[47, 64, 1, 1]***layer1.0.bn1.weight:[47]***layer1.0.conv2.weight:[18, 47, 3, 3]***layer1.0.bn2.weight:[18]***layer1.0.conv3.weight:[19, 18, 1, 1]***layer1.0.bn3.weight:[19]***layer1.0.se.conv.weight:[1, 1, 5]***layer1.0.downsample.1.weight:[19, 64, 1, 1]***layer1.0.downsample.2.weight:[19]***layer1.1.conv1.weight:[52, 19, 1, 1]***layer1.1.bn1.weight:[52]***layer1.1.conv2.weight:[22, 52, 3, 3]***layer1.1.bn2.weight:[22]***layer1.1.conv3.weight:[19, 22, 1, 1]***layer1.1.bn3.weight:[19]***layer1.1.se.conv.weight:[1, 1, 5]***layer1.2.conv1.weight:[64, 19, 1, 1]***layer1.2.bn1.weight:[64]***layer1.2.conv2.weight:[35, 64, 3, 3]***layer1.2.bn2.weight:[35]***layer1.2.conv3.weight:[19, 35, 1, 1]***layer1.2.bn3.weight:[19]***layer1.2.se.conv.weight:[1, 1, 5]***layer2.0.conv1.weight:[85, 19, 1, 1]***layer2.0.bn1.weight:[85]***layer2.0.conv2.weight:[37, 85, 3, 3]***layer2.0.bn2.weight:[37]***layer2.0.conv3.weight:[171, 37, 1, 1]***layer2.0.bn3.weight:[171]***layer2.0.se.conv.weight:[1, 1, 5]***layer2.0.downsample.1.weight:[171, 19, 1, 1]***layer2.0.downsample.2.weight:[171]***layer2.1.conv1.weight:[107, 171, 1, 1]***layer2.1.bn1.weight:[107]***layer2.1.conv2.weight:[80, 107, 3, 3]***layer2.1.bn2.weight:[80]***layer2.1.conv3.weight:[171, 80, 1, 1]***layer2.1.bn3.weight:[171]***layer2.1.se.conv.weight:[1, 1, 5]***layer2.2.conv1.weight:[120, 171, 1, 1]***layer2.2.bn1.weight:[120]***layer2.2.conv2.weight:[85, 120, 3, 3]***layer2.2.bn2.weight:[85]***layer2.2.conv3.weight:[171, 85, 1, 1]***layer2.2.bn3.weight:[171]***layer2.2.se.conv.weight:[1, 1, 5]***layer2.3.conv1.weight:[125, 171, 1, 1]***layer2.3.bn1.weight:[125]***layer2.3.conv2.weight:[87, 125, 3, 3]***layer2.3.bn2.weight:[87]***layer2.3.conv3.weight:[171, 87, 1, 1]***layer2.3.bn3.weight:[171]***layer2.3.se.conv.weight:[1, 1, 5]***layer3.0.conv1.weight:[198, 171, 1, 1]***layer3.0.bn1.weight:[198]***layer3.0.conv2.weight:[126, 198, 3, 3]***layer3.0.bn2.weight:[126]***layer3.0.conv3.weight:[818, 126, 1, 1]***layer3.0.bn3.weight:[818]***layer3.0.se.conv.weight:[1, 1, 5]***layer3.0.downsample.1.weight:[818, 171, 1, 1]***layer3.0.downsample.2.weight:[818]***layer3.1.conv1.weight:[255, 818, 1, 1]***layer3.1.bn1.weight:[255]***layer3.1.conv2.weight:[232, 255, 3, 3]***layer3.1.bn2.weight:[232]***layer3.1.conv3.weight:[818, 232, 1, 1]***layer3.1.bn3.weight:[818]***layer3.1.se.conv.weight:[1, 1, 5]***layer3.2.conv1.weight:[256, 818, 1, 1]***layer3.2.bn1.weight:[256]***layer3.2.conv2.weight:[233, 256, 3, 3]***layer3.2.bn2.weight:[233]***layer3.2.conv3.weight:[818, 233, 1, 1]***layer3.2.bn3.weight:[818]***layer3.2.se.conv.weight:[1, 1, 5]***layer3.3.conv1.weight:[253, 818, 1, 1]***layer3.3.bn1.weight:[253]***layer3.3.conv2.weight:[235, 253, 3, 3]***layer3.3.bn2.weight:[235]***layer3.3.conv3.weight:[818, 235, 1, 1]***layer3.3.bn3.weight:[818]***layer3.3.se.conv.weight:[1, 1, 5]***layer3.4.conv1.weight:[256, 818, 1, 1]***layer3.4.bn1.weight:[256]***layer3.4.conv2.weight:[225, 256, 3, 3]***layer3.4.bn2.weight:[225]***layer3.4.conv3.weight:[818, 225, 1, 1]***layer3.4.bn3.weight:[818]***layer3.4.se.conv.weight:[1, 1, 5]***layer3.5.conv1.weight:[256, 818, 1, 1]***layer3.5.bn1.weight:[256]***layer3.5.conv2.weight:[239, 256, 3, 3]***layer3.5.bn2.weight:[239]***layer3.5.conv3.weight:[818, 239, 1, 1]***layer3.5.bn3.weight:[818]***layer3.5.se.conv.weight:[1, 1, 5]***layer4.0.conv1.weight:[492, 818, 1, 1]***layer4.0.bn1.weight:[492]***layer4.0.conv2.weight:[237, 492, 3, 3]***layer4.0.bn2.weight:[237]***layer4.0.conv3.weight:[2022, 237, 1, 1]***layer4.0.bn3.weight:[2022]***layer4.0.se.conv.weight:[1, 1, 7]***layer4.0.downsample.1.weight:[2022, 818, 1, 1]***layer4.0.downsample.2.weight:[2022]***layer4.1.conv1.weight:[512, 2022, 1, 1]***layer4.1.bn1.weight:[512]***layer4.1.conv2.weight:[500, 512, 3, 3]***layer4.1.bn2.weight:[500]***layer4.1.conv3.weight:[2022, 500, 1, 1]***layer4.1.bn3.weight:[2022]***layer4.1.se.conv.weight:[1, 1, 7]***layer4.2.conv1.weight:[512, 2022, 1, 1]***layer4.2.bn1.weight:[512]***layer4.2.conv2.weight:[490, 512, 3, 3]***layer4.2.bn2.weight:[490]***layer4.2.conv3.weight:[2022, 490, 1, 1]***layer4.2.bn3.weight:[2022]***layer4.2.se.conv.weight:[1, 1, 7]***fc.weight:[1000, 2022]***layer1_2_conv3_M.weight:[256, 19]***layer2_3_conv3_M.weight:[512, 171]***layer3_5_conv3_M.weight:[1024, 818]***layer4_2_conv3_M.weight:[2048, 2022]

@ -11,11 +11,10 @@ import torch.nn as nn
import torch.nn.functional as F
from .registry import register_model
from .helpers import load_pretrained
from .helpers import load_pretrained, adapt_model_from_file
from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
__all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this
@ -104,6 +103,21 @@ default_cfgs = {
interpolation='bicubic'),
'ecaresnet18': _cfg(),
'ecaresnet50': _cfg(),
'ecaresnetlight': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNetLight_4f34b35b.pth',
interpolation='bicubic'),
'ecaresnet50d': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet50D_833caf58.pth',
interpolation='bicubic'),
'ecaresnet50d_pruned': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45899/outputs/ECAResNet50D_P_9c67f710.pth',
interpolation='bicubic'),
'ecaresnet101d': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet101D_281c5844.pth',
interpolation='bicubic'),
'ecaresnet101d_pruned': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth',
interpolation='bicubic'),
}
@ -1022,3 +1036,81 @@ def ecaresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def ecaresnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNet-50-D model with eca.
"""
default_cfg = default_cfgs['ecaresnet50d']
model = ResNet(
Bottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='eca'), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def ecaresnet50d_pruned(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNet-50-D model pruned with eca.
The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf
"""
variant = 'ecaresnet50d_pruned'
default_cfg = default_cfgs[variant]
model = ResNet(
Bottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='eca'), **kwargs)
model.default_cfg = default_cfg
model = adapt_model_from_file(model, variant)
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def ecaresnetlight(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNet-50-D light model with eca.
"""
default_cfg = default_cfgs['ecaresnetlight']
model = ResNet(
Bottleneck, [1, 1, 11, 3], stem_width=32, avg_down=True,
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='eca'), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def ecaresnet101d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNet-101-D model with eca.
"""
default_cfg = default_cfgs['ecaresnet101d']
model = ResNet(
Bottleneck, [3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True,
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='eca'), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def ecaresnet101d_pruned(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNet-101-D model pruned with eca.
The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf
"""
variant = 'ecaresnet101d_pruned'
default_cfg = default_cfgs[variant]
model = ResNet(
Bottleneck, [3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True,
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='eca'), **kwargs)
model.default_cfg = default_cfg
model = adapt_model_from_file(model, variant)
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model

Loading…
Cancel
Save