You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
311 lines
12 KiB
311 lines
12 KiB
""" Model creation / weight loading / state_dict helpers
|
|
|
|
Hacked together by / Copyright 2020 Ross Wightman
|
|
"""
|
|
import logging
|
|
import os
|
|
import math
|
|
from collections import OrderedDict
|
|
from copy import deepcopy
|
|
from typing import Callable
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.utils.model_zoo as model_zoo
|
|
|
|
from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
|
|
from .layers import Conv2dSame, Linear
|
|
|
|
|
|
# Module-level logger shared by the checkpoint / pretrained-weight loading helpers in this file.
_logger = logging.getLogger(__name__)
|
|
|
|
|
|
def load_state_dict(checkpoint_path, use_ema=False):
    """Load a state dict from a checkpoint file on disk.

    The file is read with ``torch.load`` onto the CPU. If the loaded object is a dict holding a
    ``'state_dict'`` entry (or ``'state_dict_ema'`` when ``use_ema`` is set and that entry exists),
    the nested dict is returned with any leading ``module.`` prefix stripped from its keys. Such
    prefixes typically appear when the model was saved while wrapped, e.g. in ``nn.DataParallel``.
    Any other loaded object is returned unchanged.

    Args:
        checkpoint_path: path to the checkpoint file.
        use_ema: prefer the ``'state_dict_ema'`` entry when the checkpoint contains one.

    Returns:
        The (possibly key-cleaned) state dict.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` is empty or does not point to an existing file.
    """
    if not checkpoint_path or not os.path.isfile(checkpoint_path):
        _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
        raise FileNotFoundError("No checkpoint found at '{}'".format(checkpoint_path))

    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    state_dict_key = 'state_dict'
    if isinstance(checkpoint, dict) and use_ema and 'state_dict_ema' in checkpoint:
        state_dict_key = 'state_dict_ema'

    if isinstance(checkpoint, dict) and state_dict_key in checkpoint:
        prefix = 'module.'
        state_dict = OrderedDict()
        for k, v in checkpoint[state_dict_key].items():
            # Strip the wrapper prefix only when it is a whole path component; a bare
            # startswith('module') would also mangle keys like 'module_fc.weight'.
            name = k[len(prefix):] if k.startswith(prefix) else k
            state_dict[name] = v
    else:
        # Plain state dict (or other object) saved directly; return it untouched.
        state_dict = checkpoint
    _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
    return state_dict
|
|
|
|
|
|
def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
    """Copy the weights stored at ``checkpoint_path`` into ``model``.

    The state dict is obtained through ``load_state_dict`` (honouring ``use_ema``) and handed to
    ``model.load_state_dict`` with the given ``strict`` flag. Returns ``None``.
    """
    model.load_state_dict(load_state_dict(checkpoint_path, use_ema), strict=strict)
|
|
|
|
|
|
def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
    """Restore model (and optionally optimizer / loss-scaler) state so training can resume.

    If the file holds a dict with a ``'state_dict'`` entry, that entry is loaded into ``model``
    after stripping any leading ``module.`` prefix from its keys. In that case, ``optimizer`` is
    restored from the ``'optimizer'`` entry and ``loss_scaler`` from the entry named by
    ``loss_scaler.state_dict_key``, when both the object and the entry are present. Otherwise the
    loaded object is treated as a plain model state dict.

    Args:
        model: module to restore.
        checkpoint_path: path to the checkpoint file.
        optimizer: optional object with ``load_state_dict``, restored from ``'optimizer'``.
        loss_scaler: optional object exposing ``state_dict_key`` and ``load_state_dict``.
        log_info: emit progress messages through the module logger.

    Returns:
        The epoch to resume from, or ``None`` when the checkpoint records no ``'epoch'``.
        Checkpoints with ``'version' > 1`` return ``epoch + 1``; older checkpoints
        incremented the epoch before saving.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` is empty or not an existing file.
    """
    resume_epoch = None
    if not checkpoint_path or not os.path.isfile(checkpoint_path):
        _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
        raise FileNotFoundError("No checkpoint found at '{}'".format(checkpoint_path))

    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        if log_info:
            _logger.info('Restoring model state from checkpoint...')
        prefix = 'module.'
        new_state_dict = OrderedDict()
        for k, v in checkpoint['state_dict'].items():
            # Only strip a real `module.` path component (not e.g. 'module_fc.weight').
            name = k[len(prefix):] if k.startswith(prefix) else k
            new_state_dict[name] = v
        model.load_state_dict(new_state_dict)

        if optimizer is not None and 'optimizer' in checkpoint:
            if log_info:
                _logger.info('Restoring optimizer state from checkpoint...')
            optimizer.load_state_dict(checkpoint['optimizer'])

        if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
            if log_info:
                _logger.info('Restoring AMP loss scaler state from checkpoint...')
            loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])

        if 'epoch' in checkpoint:
            resume_epoch = checkpoint['epoch']
            if 'version' in checkpoint and checkpoint['version'] > 1:
                resume_epoch += 1  # start at the next epoch, old checkpoints incremented before save

        if log_info:
            # .get: a checkpoint without an 'epoch' entry must not raise KeyError just for logging
            _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint.get('epoch')))
    else:
        model.load_state_dict(checkpoint)
        if log_info:
            _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
    return resume_epoch
|
|
|
|
|
|
def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True):
    """Download pretrained weights described by ``cfg`` and load them into ``model``.

    ``cfg`` defaults to ``model.default_cfg``. The keys read from it are ``'url'``,
    ``'first_conv'`` (only when ``in_chans != 3``), ``'classifier'`` and ``'num_classes'``.
    If no usable URL is available, a warning is logged and the model is left untouched.

    The state dict is adapted before loading:
      * ``in_chans == 1``: the first conv weight is summed over its input channels. For
        ``space2depth`` stems with more than 3 input channels, it is summed within each group of 3.
      * other ``in_chans != 3``: a 3-channel first conv weight is tiled along the channel dim and
        rescaled by ``3 / in_chans``; a first conv with a different channel count is dropped and
        loading becomes non-strict.
      * ``num_classes == 1000`` with a 1001-class pretrained head: the first (background) row
        of the classifier weight and bias is removed.
      * any other class-count mismatch: the classifier weight and bias are dropped and loading
        becomes non-strict.

    Args:
        model: module to load weights into.
        cfg: pretrained config dict; falls back to ``model.default_cfg``.
        num_classes: number of classes of the model being loaded into.
        in_chans: number of input channels of the model being loaded into.
        filter_fn: optional callable applied to the downloaded state dict before adaptation.
        strict: ``strict`` flag for ``model.load_state_dict`` (may be forced to ``False``).
    """
    if cfg is None:
        # default None so the guard below handles models without a default_cfg attribute
        cfg = getattr(model, 'default_cfg', None)
    if cfg is None or 'url' not in cfg or not cfg['url']:
        _logger.warning("Pretrained model URL is invalid, using random initialization.")
        return

    state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu')

    if filter_fn is not None:
        state_dict = filter_fn(state_dict)

    if in_chans == 1:
        conv1_name = cfg['first_conv']
        _logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel', conv1_name)
        conv1_weight = state_dict[conv1_name + '.weight']
        # Some weights are in torch.half, ensure it's float for sum on CPU
        conv1_type = conv1_weight.dtype
        conv1_weight = conv1_weight.float()
        O, I, J, K = conv1_weight.shape
        if I > 3:
            assert conv1_weight.shape[1] % 3 == 0
            # For models with space2depth stems
            conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K)
            conv1_weight = conv1_weight.sum(dim=2, keepdim=False)
        else:
            conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
        conv1_weight = conv1_weight.to(conv1_type)
        state_dict[conv1_name + '.weight'] = conv1_weight
    elif in_chans != 3:
        conv1_name = cfg['first_conv']
        conv1_weight = state_dict[conv1_name + '.weight']
        conv1_type = conv1_weight.dtype
        conv1_weight = conv1_weight.float()
        O, I, J, K = conv1_weight.shape
        if I != 3:
            _logger.warning('Deleting first conv (%s) from pretrained weights.', conv1_name)
            del state_dict[conv1_name + '.weight']
            strict = False
        else:
            # NOTE this strategy should be better than random init, but there could be other combinations of
            # the original RGB input layer weights that'd work better for specific cases.
            _logger.info('Repeating first conv (%s) weights in channel dim.', conv1_name)
            repeat = int(math.ceil(in_chans / 3))
            conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
            # rescale so the summed response over input channels stays comparable to the 3-channel original
            conv1_weight *= (3 / float(in_chans))
            conv1_weight = conv1_weight.to(conv1_type)
            state_dict[conv1_name + '.weight'] = conv1_weight

    classifier_name = cfg['classifier']
    if num_classes == 1000 and cfg['num_classes'] == 1001:
        # special case for imagenet trained models with extra background class in pretrained weights
        classifier_weight = state_dict[classifier_name + '.weight']
        state_dict[classifier_name + '.weight'] = classifier_weight[1:]
        classifier_bias = state_dict[classifier_name + '.bias']
        state_dict[classifier_name + '.bias'] = classifier_bias[1:]
    elif num_classes != cfg['num_classes']:
        # completely discard fully connected for all other differences between pretrained and created model
        del state_dict[classifier_name + '.weight']
        del state_dict[classifier_name + '.bias']
        strict = False

    model.load_state_dict(state_dict, strict=strict)
|
|
|
|
|
|
def extract_layer(model, layer):
    """Return the submodule of ``model`` addressed by the dotted path ``layer``.

    Numeric path components index into the current module (``m[int(part)]``); other components
    are attribute lookups. When ``model`` has a ``module`` attribute (e.g. a DataParallel-style
    wrapper), a path not starting with ``module`` is resolved against ``model.module``. When it
    has no such attribute, a leading ``module`` component is dropped. Resolution stops at the
    first component that does not exist, and the deepest module reached so far is returned.
    """
    parts = layer.split('.')
    wrapped = hasattr(model, 'module')

    current = model.module if (wrapped and parts[0] != 'module') else model
    if parts[0] == 'module' and not wrapped:
        parts = parts[1:]

    for part in parts:
        if not hasattr(current, part):
            break  # unresolvable component: fall back to what we have
        current = current[int(part)] if part.isdigit() else getattr(current, part)
    return current
|
|
|
|
|
|
def set_layer(model, layer, val):
    """Replace the submodule at dotted path ``layer`` inside ``model`` with ``val``.

    Path handling mirrors ``extract_layer``: numeric components index into the current module and
    others are attribute lookups. If ``model`` has a ``module`` attribute and the path does not
    start with ``module``, the path is resolved inside ``model.module``. Unlike ``extract_layer``,
    a leading ``module`` component is not stripped for unwrapped models.

    The code first counts how many path components resolve, then walks to the parent of the last
    resolved component and ``setattr``s ``val`` under that component's name. In the normal case,
    where every component exists, this replaces the final component on its parent.
    """
    parts = layer.split('.')
    root = model
    if hasattr(model, 'module') and parts[0] != 'module':
        root = model.module

    # Count the components that can be resolved; missing ones are skipped, not fatal.
    resolved = 0
    probe = root
    for part in parts:
        if not hasattr(probe, part):
            continue
        probe = probe[int(part)] if part.isdigit() else getattr(probe, part)
        resolved += 1
    target = resolved - 1

    parent = root
    for part in parts[:target]:
        parent = parent[int(part)] if part.isdigit() else getattr(parent, part)
    setattr(parent, parts[target], val)
|
|
|
|
|
|
def adapt_model_from_string(parent_module, model_string):
    """Build a copy of ``parent_module`` whose layers are resized to the shapes in ``model_string``.

    ``model_string`` is a ``'***'``-separated list of ``key:shape`` entries. ``key`` is a parameter
    name (e.g. ``'<module path>.weight'``). For ``shape``, the first and last characters
    (presumably ``[``/``]`` — TODO confirm against the pruned files) are stripped, and the rest is
    parsed as comma-separated ints. Entries whose shape is empty are skipped.

    A deep copy of ``parent_module`` is made. Each ``nn.Conv2d``/``Conv2dSame``, ``nn.BatchNorm2d``
    and ``nn.Linear`` in the original is then replaced in the copy with a freshly constructed layer
    sized from the parsed shapes, copying the other hyper-parameters from the original layer. The
    new layers are newly initialized, not copies of the old weights. Both the copy and
    ``parent_module`` are put in eval mode, and the copy is returned.

    Raises:
        KeyError: if a conv/BN/linear layer has no ``'<name>.weight'`` entry in ``model_string``.
    """
    separator = '***'
    # maps parameter name -> list of ints (a shape), not tensors
    state_dict = {}
    lst_shape = model_string.split(separator)
    for k in lst_shape:
        k = k.split(':')
        key = k[0]
        shape = k[1][1:-1].split(',')
        if shape[0] != '':
            state_dict[key] = [int(i) for i in shape]

    new_module = deepcopy(parent_module)
    # Iterate the *original* module's names so replacements made in the copy don't affect traversal.
    for n, m in parent_module.named_modules():
        old_module = extract_layer(parent_module, n)
        if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame):
            if isinstance(old_module, Conv2dSame):
                conv = Conv2dSame
            else:
                conv = nn.Conv2d
            s = state_dict[n + '.weight']
            in_channels = s[1]
            out_channels = s[0]
            g = 1
            # NOTE(review): any grouped conv is treated as depthwise (groups == in == out channels);
            # a non-depthwise grouped conv would be rebuilt incorrectly here.
            if old_module.groups > 1:
                in_channels = out_channels
                g = in_channels
            new_conv = conv(
                in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size,
                bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
                groups=g, stride=old_module.stride)
            set_layer(new_module, n, new_conv)
        if isinstance(old_module, nn.BatchNorm2d):
            # NOTE(review): track_running_stats is forced to True rather than copied from old_module.
            new_bn = nn.BatchNorm2d(
                num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
                affine=old_module.affine, track_running_stats=True)
            set_layer(new_module, n, new_bn)
        if isinstance(old_module, nn.Linear):
            # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
            num_features = state_dict[n + '.weight'][1]
            new_fc = Linear(
                in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
            set_layer(new_module, n, new_fc)
            # keep the model's advertised feature width in sync with the resized linear input
            if hasattr(new_module, 'num_features'):
                new_module.num_features = num_features
    new_module.eval()
    parent_module.eval()

    return new_module
|
|
|
|
|
|
def adapt_model_from_file(parent_module, model_variant):
    """Resize ``parent_module`` using the pruned-architecture description for ``model_variant``.

    The description is read from ``pruned/<model_variant>.txt`` next to this source file. The
    whitespace-stripped contents are passed to ``adapt_model_from_string``, whose result is
    returned.
    """
    spec_dir = os.path.join(os.path.dirname(__file__), 'pruned')
    spec_path = os.path.join(spec_dir, model_variant + '.txt')
    with open(spec_path, 'r') as spec_file:
        model_string = spec_file.read().strip()
    return adapt_model_from_string(parent_module, model_string)
|
|
|
|
|
|
def default_cfg_for_features(default_cfg):
    """Return a deep copy of ``default_cfg`` without classifier-specific fields.

    ``'num_classes'``, ``'crop_pct'`` and ``'classifier'`` are removed when present, since they
    say little about a feature-extraction backbone. The input dict is not modified.
    """
    # TODO: consider adding a default final pool size here.
    feature_cfg = deepcopy(default_cfg)
    for field in ('num_classes', 'crop_pct', 'classifier'):
        feature_cfg.pop(field, None)
    return feature_cfg
|
|
|
|
|
|
def build_model_with_cfg(
        model_cls: Callable,
        variant: str,
        pretrained: bool,
        default_cfg: dict,
        model_cfg: dict = None,
        feature_cfg: dict = None,
        pretrained_strict: bool = True,
        pretrained_filter_fn: Callable = None,
        **kwargs):
    """Instantiate a model and optionally prune it, load pretrained weights and wrap it for features.

    Args:
        model_cls: model class or factory to instantiate.
        variant: variant name; used to locate ``pruned/<variant>.txt`` when ``pruned=True``.
        pretrained: load pretrained weights via ``load_pretrained`` (using ``model.default_cfg``).
        default_cfg: config attached to the model as ``default_cfg``.
        model_cfg: if given, passed to ``model_cls`` as ``cfg=``.
        feature_cfg: extra kwargs for the feature wrapper; may contain ``'feature_cls'`` (a class,
            or a string containing ``'hook'``). The caller's dict is not modified.
        pretrained_strict: ``strict`` flag forwarded to ``load_pretrained``.
        pretrained_filter_fn: state-dict filter forwarded to ``load_pretrained``.
        **kwargs: model kwargs. ``pruned`` and ``features_only`` are consumed here. ``out_indices``
            is consumed only when ``features_only`` is set. Everything else goes to ``model_cls``.

    Returns:
        The model, or a feature-wrapper around it when ``features_only=True``.

    Raises:
        AssertionError: if ``feature_cfg['feature_cls']`` is an unrecognised string.
    """
    pruned = kwargs.pop('pruned', False)
    features = False
    # Shallow copy: 'feature_cls' is popped and 'out_indices' may be set below, and those
    # changes must not leak into a dict the caller may reuse (e.g. a module-level constant).
    feature_cfg = dict(feature_cfg) if feature_cfg else {}

    if kwargs.pop('features_only', False):
        features = True
        feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
        if 'out_indices' in kwargs:
            feature_cfg['out_indices'] = kwargs.pop('out_indices')

    model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
    model.default_cfg = deepcopy(default_cfg)

    if pruned:
        # the adapted model is a deepcopy, so it keeps the default_cfg attribute set above
        model = adapt_model_from_file(model, variant)

    # for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
    num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
    if pretrained:
        load_pretrained(
            model,
            num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),
            filter_fn=pretrained_filter_fn, strict=pretrained_strict)

    if features:
        feature_cls = FeatureListNet
        if 'feature_cls' in feature_cfg:
            feature_cls = feature_cfg.pop('feature_cls')
            if isinstance(feature_cls, str):
                feature_cls = feature_cls.lower()
                if 'hook' in feature_cls:
                    feature_cls = FeatureHookNet
                else:
                    # explicit raise (same type as the former `assert False`) so it survives `python -O`
                    raise AssertionError(f'Unknown feature class {feature_cls}')
        model = feature_cls(model, **feature_cfg)
        model.default_cfg = default_cfg_for_features(default_cfg)  # add back default_cfg

    return model
|