@@ -1,8 +1,11 @@
import torch
import torch.nn as nn
from copy import deepcopy
import torch.utils.model_zoo as model_zoo
import os
import logging
from collections import OrderedDict

from timm.models.layers.conv2d_same import Conv2dSame


def load_state_dict(checkpoint_path, use_ema=False):
@@ -101,4 +104,91 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
def extract_layer(model, layer):
    """Return the submodule of `model` addressed by a dotted path such as
    'blocks.0.conv'. Numeric path components index into sequential containers;
    non-numeric components are resolved as attributes."""
    layer = layer.split('.')
    module = model
    # Unwrap a DataParallel-style wrapper when the path does not include it,
    # and drop a leading 'module' component when there is no wrapper.
    if hasattr(model, 'module') and layer[0] != 'module':
        module = model.module
    if not hasattr(model, 'module') and layer[0] == 'module':
        layer = layer[1:]
    for l in layer:
        if hasattr(module, l):
            if not l.isdigit():
                module = getattr(module, l)
            else:
                module = module[int(l)]
        else:
            # Path cannot be resolved further; return the deepest module found.
            return module
    return module
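
# Usage sketch for extract_layer ('blocks.0.conv' is a hypothetical path;
# valid names depend on the concrete model):
#
#   conv = extract_layer(model, 'blocks.0.conv')  # same as model.blocks[0].conv
#
# The dotted-path form matches state_dict keys, which is what makes it useful
# when iterating over named modules below.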

def set_layer(model, layer, val):
    """Replace the submodule addressed by the dotted path `layer` with `val`,
    resolving parent containers the same way as extract_layer."""
    layer = layer.split('.')
    module = model
    if hasattr(model, 'module') and layer[0] != 'module':
        module = model.module
    lst_index = 0
    module2 = module
    # First pass: walk as far down the path as it resolves, counting the steps.
    for l in layer:
        if hasattr(module2, l):
            if not l.isdigit():
                module2 = getattr(module2, l)
            else:
                module2 = module2[int(l)]
            lst_index += 1
    lst_index -= 1
    # Second pass: walk to the parent of the final component, then assign.
    for l in layer[:lst_index]:
        if not l.isdigit():
            module = getattr(module, l)
        else:
            module = module[int(l)]
    l = layer[lst_index]
    setattr(module, l, val)
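
# Usage sketch (hypothetical path and sizes): swap one conv for a re-sized one.
#
#   set_layer(model, 'blocks.0.conv', nn.Conv2d(16, 32, kernel_size=3, padding=1))
#   # model.blocks[0].conv is now the new 16->32 conv.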

def adapt_model_from_string(parent_module, model_string):
    """Build a copy of `parent_module` whose Conv2d/Conv2dSame, BatchNorm2d and
    Linear layers are re-sized to the parameter shapes encoded in `model_string`.

    The string holds one `name:[shape]` entry per parameter, joined by '***',
    e.g. 'conv_stem.weight:[32, 3, 3, 3]***bn1.weight:[32]'.
    """
    separator = '***'
    state_dict = {}
    lst_shape = model_string.split(separator)
    for k in lst_shape:
        k = k.split(':')
        key = k[0]
        # Strip the surrounding brackets, then parse the comma-separated dims.
        shape = k[1][1:-1].split(',')
        if shape[0] != '':
            state_dict[key] = [int(i) for i in shape]

    new_module = deepcopy(parent_module)
    for n, m in parent_module.named_modules():
        old_module = extract_layer(parent_module, n)
        if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame):
            if isinstance(old_module, Conv2dSame):
                conv = Conv2dSame
            else:
                conv = nn.Conv2d
            s = state_dict[n + '.weight']
            in_channels = s[1]
            out_channels = s[0]
            if old_module.groups > 1:
                # Depthwise conv: channel count must track the group count.
                in_channels = out_channels
                g = in_channels
            else:
                g = 1
            new_conv = conv(
                in_channels=in_channels, out_channels=out_channels,
                kernel_size=old_module.kernel_size, bias=old_module.bias is not None,
                padding=old_module.padding, dilation=old_module.dilation,
                groups=g, stride=old_module.stride)
            set_layer(new_module, n, new_conv)
        if isinstance(old_module, nn.BatchNorm2d):
            new_bn = nn.BatchNorm2d(
                num_features=state_dict[n + '.weight'][0], eps=old_module.eps,
                momentum=old_module.momentum, affine=old_module.affine,
                track_running_stats=True)
            set_layer(new_module, n, new_bn)
        if isinstance(old_module, nn.Linear):
            new_fc = nn.Linear(
                in_features=state_dict[n + '.weight'][1],
                out_features=old_module.out_features,
                bias=old_module.bias is not None)
            set_layer(new_module, n, new_fc)
    new_module.eval()
    parent_module.eval()
    return new_module
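
# Sketch of building a compatible model_string from a pruned checkpoint
# ('pruned_state_dict' is a hypothetical dict of name -> tensor; the format
# is inferred from the parsing logic above):
#
#   model_string = '***'.join(
#       f'{k}:{list(v.shape)}' for k, v in pruned_state_dict.items())
#   pruned_model = adapt_model_from_string(model, model_string)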