Merge branch 'adding_ECA_resnet' of https://github.com/yoniaflalo/pytorch-image-models into yoniaflalo-adding_ECA_resnet

pull/136/head
Ross Wightman 4 years ago
commit ba793f5c1c

@ -1,8 +1,11 @@
import torch
import torch.nn as nn
from copy import deepcopy
import torch.utils.model_zoo as model_zoo
import os
import logging
from collections import OrderedDict
from timm.models.layers.conv2d_same import Conv2dSame
def load_state_dict(checkpoint_path, use_ema=False):
@ -101,4 +104,91 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
def extract_layer(model, layer):
layer = layer.split('.')
module = model
if hasattr(model, 'module') and layer[0] != 'module':
module = model.module
if not hasattr(model, 'module') and layer[0] == 'module':
layer = layer[1:]
for l in layer:
if hasattr(module, l):
if not l.isdigit():
module = getattr(module, l)
else:
module = module[int(l)]
else:
return module
return module
def set_layer(model, layer, val):
layer = layer.split('.')
module = model
if hasattr(model, 'module') and layer[0] != 'module':
module = model.module
lst_index = 0
module2 = module
for l in layer:
if hasattr(module2, l):
if not l.isdigit():
module2 = getattr(module2, l)
else:
module2 = module2[int(l)]
lst_index += 1
lst_index -= 1
for l in layer[:lst_index]:
if not l.isdigit():
module = getattr(module, l)
else:
module = module[int(l)]
l = layer[lst_index]
setattr(module, l, val)
def adapt_model_from_string(parent_module, model_string):
separator = '***'
state_dict = {}
lst_shape = model_string.split(separator)
for k in lst_shape:
k = k.split(':')
key = k[0]
shape = k[1][1:-1].split(',')
if shape[0] != '':
state_dict[key] = [int(i) for i in shape]
new_module = deepcopy(parent_module)
for n, m in parent_module.named_modules():
old_module = extract_layer(parent_module, n)
if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame):
if isinstance(old_module, Conv2dSame):
conv = Conv2dSame
else:
conv = nn.Conv2d
s = state_dict[n + '.weight']
in_channels = s[1]
out_channels = s[0]
if old_module.groups > 1:
in_channels = out_channels
g = in_channels
else:
g = 1
new_conv = conv(in_channels=in_channels, out_channels=out_channels,
kernel_size=old_module.kernel_size, bias=old_module.bias is not None,
padding=old_module.padding, dilation=old_module.dilation,
groups=g, stride=old_module.stride)
set_layer(new_module, n, new_conv)
if isinstance(old_module, nn.BatchNorm2d):
new_bn = nn.BatchNorm2d(num_features=state_dict[n + '.weight'][0], eps=old_module.eps,
momentum=old_module.momentum,
affine=old_module.affine,
track_running_stats=True)
set_layer(new_module, n, new_bn)
if isinstance(old_module, nn.Linear):
new_fc = nn.Linear(in_features=state_dict[n + '.weight'][1], out_features=old_module.out_features,
bias=old_module.bias is not None)
set_layer(new_module, n, new_fc)
new_module.eval()
parent_module.eval()
return new_module

File diff suppressed because one or more lines are too long
Loading…
Cancel
Save