You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
pytorch-image-models/timm/models/helpers.py

375 lines
15 KiB

""" Model creation / weight loading / state_dict helpers
Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
import os
import math
from collections import OrderedDict
from copy import deepcopy
from typing import Callable
import torch
4 years ago
import torch.nn as nn
from torch.hub import load_state_dict_from_url, download_url_to_file, urlparse, HASH_REGEX
try:
from torch.hub import get_dir
except ImportError:
from torch.hub import _get_torch_home as get_dir
from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
from .layers import Conv2dSame, Linear
_logger = logging.getLogger(__name__)
def load_state_dict(checkpoint_path, use_ema=False):
    """Load a state dict from a checkpoint file, mapped onto the CPU.

    When the loaded checkpoint holds the weights under ``'state_dict'`` (or under
    ``'state_dict_ema'`` if ``use_ema`` is set and the checkpoint is a dict with that key),
    those weights are returned in a new ``OrderedDict`` with a leading ``module.`` removed
    from every key. Otherwise the loaded object is returned unchanged.

    Args:
        checkpoint_path: Path of the checkpoint file to read.
        use_ema: Prefer the ``'state_dict_ema'`` entry when the checkpoint has one.

    Returns:
        The (possibly unwrapped and re-keyed) state dict.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` is empty or is not an existing file.
    """
    if not checkpoint_path or not os.path.isfile(checkpoint_path):
        message = "No checkpoint found at '{}'".format(checkpoint_path)
        _logger.error(message)
        raise FileNotFoundError(message)

    # NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    state_dict_key = 'state_dict'
    if isinstance(checkpoint, dict) and use_ema and 'state_dict_ema' in checkpoint:
        state_dict_key = 'state_dict_ema'

    if state_dict_key in checkpoint:
        # Match the full 'module.' prefix (dot included): testing for 'module' alone would
        # also chop 7 characters off unrelated keys such as 'modules.0.weight'.
        prefix = 'module.'
        state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in checkpoint[state_dict_key].items())
    else:
        state_dict = checkpoint
    _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
    return state_dict
def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
    """Read weights from ``checkpoint_path`` and load them into ``model``.

    Args:
        model: Object exposing ``load_state_dict(state_dict, strict=...)``.
        checkpoint_path: Path of the checkpoint file, forwarded to ``load_state_dict``.
        use_ema: Forwarded to ``load_state_dict`` to prefer the EMA weights entry.
        strict: Forwarded to ``model.load_state_dict``.

    Returns:
        Whatever ``model.load_state_dict`` returns, so callers passing ``strict=False`` can
        inspect the result (previously this function always returned ``None``).
    """
    state_dict = load_state_dict(checkpoint_path, use_ema)
    return model.load_state_dict(state_dict, strict=strict)
def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
    """Restore model (and optionally optimizer / loss scaler) state from a checkpoint file.

    If the checkpoint is a dict with a ``'state_dict'`` entry, the model weights are loaded
    from it (with a leading ``module.`` removed from each key), followed by the optimizer
    state (``'optimizer'`` key) and the loss scaler state (``loss_scaler.state_dict_key``)
    when those objects are given and their keys are present. Any other checkpoint is passed
    to ``model.load_state_dict`` as is.

    Args:
        model: Model whose ``load_state_dict`` receives the weights.
        checkpoint_path: Path of the checkpoint file.
        optimizer: Optional optimizer to restore.
        loss_scaler: Optional loss scaler exposing ``state_dict_key`` and ``load_state_dict``.
        log_info: Emit info-level log messages while restoring.

    Returns:
        The epoch to resume from, or ``None`` when the checkpoint carries no ``'epoch'``.
        For checkpoints with ``'version'`` greater than 1 this is the stored epoch plus one.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` is not an existing file.
    """
    if not os.path.isfile(checkpoint_path):
        message = "No checkpoint found at '{}'".format(checkpoint_path)
        _logger.error(message)
        raise FileNotFoundError(message)

    # NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    if not (isinstance(checkpoint, dict) and 'state_dict' in checkpoint):
        model.load_state_dict(checkpoint)
        if log_info:
            _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
        return None

    if log_info:
        _logger.info('Restoring model state from checkpoint...')
    # Match the full 'module.' prefix (dot included) so that keys which merely begin with
    # the letters 'module' are left intact.
    prefix = 'module.'
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in checkpoint['state_dict'].items())
    model.load_state_dict(new_state_dict)

    if optimizer is not None and 'optimizer' in checkpoint:
        if log_info:
            _logger.info('Restoring optimizer state from checkpoint...')
        optimizer.load_state_dict(checkpoint['optimizer'])

    if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
        if log_info:
            _logger.info('Restoring AMP loss scaler state from checkpoint...')
        loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])

    resume_epoch = None
    if 'epoch' in checkpoint:
        resume_epoch = checkpoint['epoch']
        if 'version' in checkpoint and checkpoint['version'] > 1:
            resume_epoch += 1  # start at the next epoch, old checkpoints incremented before save

    if log_info:
        # .get() keeps logging from raising KeyError for checkpoints saved without an epoch.
        _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint.get('epoch')))
    return resume_epoch
def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_hash=False):
    r"""Loads a custom (read non .pth) weight file

    Downloads the checkpoint file into a cache dir like the torch.hub based loaders, then
    hands the cached file to a passed in custom load fn, or to the model's own
    `load_pretrained` member fn.

    If the file is already present in `model_dir`, the download is skipped.
    `model_dir` is ``<hub_dir>/checkpoints`` where `hub_dir` is the directory returned by
    :func:`~torch.hub.get_dir`.

    Args:
        model: The instantiated model to load weights into
        cfg (dict): Default pretrained model cfg; falls back to `model.default_cfg` when None
        load_fn: An external stand alone fn, called as `load_fn(model, cached_file)`, that loads
            weights into the provided model; otherwise a fn named 'load_pretrained' on the model
            will be called if it exists
        progress (bool, optional): whether or not to display a progress bar to stderr. Default: False
        check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention
            ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
            digits of the SHA256 hash of the contents of the file. The hash is used to
            ensure unique names and to verify the contents of the file. Default: False

    Returns:
        None. A warning is logged (and nothing is loaded) when no URL or no load fn is available.
    """
    if cfg is None:
        # Default of None lets a model without `default_cfg` reach the warning below
        # instead of failing with AttributeError.
        cfg = getattr(model, 'default_cfg', None)
    if cfg is None or 'url' not in cfg or not cfg['url']:
        _logger.warning("Pretrained model URL does not exist, using random initialization.")
        return
    url = cfg['url']

    # Issue warning to move data if old env is set
    if os.getenv('TORCH_MODEL_ZOO'):
        _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    hub_dir = get_dir()
    model_dir = os.path.join(hub_dir, 'checkpoints')
    os.makedirs(model_dir, exist_ok=True)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        if check_hash:
            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
            hash_prefix = r.group(1) if r else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)

    if load_fn is not None:
        load_fn(model, cached_file)
    elif hasattr(model, 'load_pretrained'):
        model.load_pretrained(cached_file)
    else:
        _logger.warning("Valid function to load pretrained weights is not available, using random initialization.")
def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
    """Download pretrained weights from ``cfg['url']`` and load them into ``model``.

    The downloaded state dict is adapted before loading:

    * ``in_chans == 1``: the first conv weight is summed over its input-channel dim
      (in groups of three when it has more than 3 input channels).
    * other ``in_chans != 3``: a 3-input-channel first conv weight is repeated along the
      channel dim, cut to ``in_chans`` and rescaled by ``3 / in_chans``; a first conv with any
      other input-channel count is dropped and loading becomes non-strict.
    * classifier: if ``num_classes == 1000`` and the cfg says 1001, the first classifier row
      is removed; for any other mismatch the classifier weights are dropped and loading
      becomes non-strict.

    Args:
        model: Model to load into; ``model.default_cfg`` is used when ``cfg`` is None.
        cfg (dict): Pretrained cfg; reads ``'url'``, ``'first_conv'``, ``'classifier'`` and
            ``'num_classes'``.
        num_classes: Number of classes of the created model.
        in_chans: Number of input channels of the created model.
        filter_fn: Optional callable applied to the downloaded state dict; its return value
            replaces the state dict.
        strict: Passed to ``model.load_state_dict`` unless weights had to be dropped.
        progress: Passed to ``load_state_dict_from_url``.

    Returns:
        None. A warning is logged (and nothing is loaded) when no URL is available.
    """
    if cfg is None:
        # Default of None lets a model without `default_cfg` reach the warning below
        # instead of failing with AttributeError.
        cfg = getattr(model, 'default_cfg', None)
    if cfg is None or 'url' not in cfg or not cfg['url']:
        _logger.warning("Pretrained model URL does not exist, using random initialization.")
        return

    state_dict = load_state_dict_from_url(cfg['url'], progress=progress, map_location='cpu')

    if filter_fn is not None:
        state_dict = filter_fn(state_dict)

    if in_chans == 1:
        conv1_name = cfg['first_conv']
        conv1_key = conv1_name + '.weight'
        _logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name)
        conv1_weight = state_dict[conv1_key]
        # Some weights are in torch.half, ensure it's float for sum on CPU
        conv1_type = conv1_weight.dtype
        conv1_weight = conv1_weight.float()
        out_ch, in_ch, k_h, k_w = conv1_weight.shape
        if in_ch > 3:
            assert in_ch % 3 == 0
            # For models with space2depth stems
            conv1_weight = conv1_weight.reshape(out_ch, in_ch // 3, 3, k_h, k_w)
            conv1_weight = conv1_weight.sum(dim=2, keepdim=False)
        else:
            conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
        state_dict[conv1_key] = conv1_weight.to(conv1_type)
    elif in_chans != 3:
        conv1_name = cfg['first_conv']
        conv1_key = conv1_name + '.weight'
        conv1_weight = state_dict[conv1_key]
        if conv1_weight.shape[1] != 3:
            _logger.warning('Deleting first conv (%s) from pretrained weights.' % conv1_name)
            del state_dict[conv1_key]
            strict = False
        else:
            # NOTE this strategy should be better than random init, but there could be other combinations of
            # the original RGB input layer weights that'd work better for specific cases.
            _logger.info('Repeating first conv (%s) weights in channel dim.' % conv1_name)
            conv1_type = conv1_weight.dtype
            repeat = int(math.ceil(in_chans / 3))
            conv1_weight = conv1_weight.float().repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
            conv1_weight *= (3 / float(in_chans))
            state_dict[conv1_key] = conv1_weight.to(conv1_type)

    classifier_name = cfg['classifier']
    classifier_weight_key = classifier_name + '.weight'
    classifier_bias_key = classifier_name + '.bias'
    if num_classes == 1000 and cfg['num_classes'] == 1001:
        # FIXME this special case is problematic as number of pretrained weight sources increases
        # special case for imagenet trained models with extra background class in pretrained weights
        state_dict[classifier_weight_key] = state_dict[classifier_weight_key][1:]
        if classifier_bias_key in state_dict:  # tolerate classifiers saved without a bias
            state_dict[classifier_bias_key] = state_dict[classifier_bias_key][1:]
    elif num_classes != cfg['num_classes']:
        # completely discard fully connected for all other differences between pretrained and created model
        # pop() rather than del: a classifier without a bias must not abort the load with KeyError
        state_dict.pop(classifier_weight_key, None)
        state_dict.pop(classifier_bias_key, None)
        strict = False

    model.load_state_dict(state_dict, strict=strict)
4 years ago
def extract_layer(model, layer):
    """Walk a dotted path such as ``'blocks.0.conv'`` down from ``model``.

    If ``model`` has a ``module`` attribute and the path does not start with ``'module'``,
    the walk starts at ``model.module``. If ``model`` has no ``module`` attribute but the
    path starts with ``'module'``, that first component is ignored.

    All-digit components are used as integer indices, everything else as attribute names.

    Returns:
        The object at the end of the path, or the deepest object reached before a component
        that could not be resolved with ``hasattr``.
    """
    parts = layer.split('.')
    wrapped = hasattr(model, 'module')
    current = model
    if wrapped and parts[0] != 'module':
        current = model.module
    if not wrapped and parts[0] == 'module':
        parts = parts[1:]

    for part in parts:
        if not hasattr(current, part):
            # unresolved component: stop and hand back what has been reached so far
            return current
        current = current[int(part)] if part.isdigit() else getattr(current, part)
    return current
def set_layer(model, layer, val):
    """Assign ``val`` at the dotted path ``layer`` underneath ``model``.

    If ``model`` has a ``module`` attribute and the path does not start with ``'module'``,
    the path is resolved relative to ``model.module``.

    A first pass counts how many path components resolve (via ``hasattr``) while descending;
    components that do not resolve are skipped. The component at index ``count - 1`` is the
    one assigned with ``setattr``, on the object reached by following the components before
    it. All-digit components are used as integer indices, others as attribute names.
    """
    parts = layer.split('.')
    root = model
    if hasattr(model, 'module') and parts[0] != 'module':
        root = model.module

    # Pass 1: count the components that can be resolved while descending.
    resolved = 0
    probe = root
    for part in parts:
        if hasattr(probe, part):
            probe = probe[int(part)] if part.isdigit() else getattr(probe, part)
            resolved += 1
    target_idx = resolved - 1

    # Pass 2: descend to the owner of the target component, then replace it.
    owner = root
    for part in parts[:target_idx]:
        owner = owner[int(part)] if part.isdigit() else getattr(owner, part)
    setattr(owner, parts[target_idx], val)
def adapt_model_from_string(parent_module, model_string):
    """Build a copy of ``parent_module`` whose layers are resized per ``model_string``.

    ``model_string`` is a list of ``key:shape`` entries joined by ``'***'``. The first and
    last character of each shape text are dropped (presumably enclosing brackets) and the
    remainder is parsed as comma separated ints; entries with an empty shape are ignored.

    For every sub-module of ``parent_module`` the matching ``<name>.weight`` shape is used to
    create a fresh replacement in the copy:

    * ``nn.Conv2d`` / ``Conv2dSame``: out/in channels come from shape dims 0/1; when the old
      conv has ``groups > 1`` both in-channels and groups are set to the out-channel count.
    * ``nn.BatchNorm2d``: ``num_features`` comes from shape dim 0.
    * ``nn.Linear``: ``in_features`` comes from shape dim 1, and the copy's ``num_features``
      attribute (if it has one) is updated to that value.

    Both the copy and ``parent_module`` are switched to eval mode.

    Returns:
        The adapted copy; replaced layers are newly constructed modules.
    """
    shapes = {}
    for entry in model_string.split('***'):
        fields = entry.split(':')
        dims = fields[1][1:-1].split(',')
        if dims[0] != '':
            shapes[fields[0]] = [int(d) for d in dims]

    new_module = deepcopy(parent_module)
    for name, _ in parent_module.named_modules():
        old_module = extract_layer(parent_module, name)
        weight_key = name + '.weight'

        if isinstance(old_module, (nn.Conv2d, Conv2dSame)):
            conv_cls = Conv2dSame if isinstance(old_module, Conv2dSame) else nn.Conv2d
            weight_shape = shapes[weight_key]
            in_channels = weight_shape[1]
            out_channels = weight_shape[0]
            groups = 1
            if old_module.groups > 1:
                # grouped conv: in-channels and group count both follow the out-channel count
                in_channels = out_channels
                groups = out_channels
            set_layer(new_module, name, conv_cls(
                in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size,
                bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
                groups=groups, stride=old_module.stride))

        if isinstance(old_module, nn.BatchNorm2d):
            set_layer(new_module, name, nn.BatchNorm2d(
                num_features=shapes[weight_key][0], eps=old_module.eps, momentum=old_module.momentum,
                affine=old_module.affine, track_running_stats=True))

        if isinstance(old_module, nn.Linear):
            # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
            num_features = shapes[weight_key][1]
            set_layer(new_module, name, Linear(
                in_features=num_features, out_features=old_module.out_features,
                bias=old_module.bias is not None))
            if hasattr(new_module, 'num_features'):
                new_module.num_features = num_features

    new_module.eval()
    parent_module.eval()
    return new_module
def adapt_model_from_file(parent_module, model_variant):
    """Adapt ``parent_module`` using the spec stored in ``pruned/<model_variant>.txt``.

    The file is looked up in the ``pruned`` directory next to this source file; its
    whitespace-stripped content is handed to ``adapt_model_from_string``.

    Returns:
        The result of ``adapt_model_from_string``.
    """
    pruned_dir = os.path.join(os.path.dirname(__file__), 'pruned')
    spec_path = os.path.join(pruned_dir, model_variant + '.txt')
    with open(spec_path, 'r') as spec_file:
        spec = spec_file.read().strip()
    return adapt_model_from_string(parent_module, spec)
def default_cfg_for_features(default_cfg):
    """Return a deep copy of ``default_cfg`` without the classifier-related fields.

    The fields ``'num_classes'``, ``'crop_pct'`` and ``'classifier'`` are dropped from the
    copy when present; the input dict is left untouched.
    """
    # fields with little relevance for a feature backbone (add default final pool size?)
    classifier_only_fields = ('num_classes', 'crop_pct', 'classifier')
    features_cfg = deepcopy(default_cfg)
    for field in classifier_only_fields:
        features_cfg.pop(field, None)
    return features_cfg
def build_model_with_cfg(
        model_cls: Callable,
        variant: str,
        pretrained: bool,
        default_cfg: dict,
        model_cfg: dict = None,
        feature_cfg: dict = None,
        pretrained_strict: bool = True,
        pretrained_filter_fn: Callable = None,
        pretrained_custom_load: bool = False,
        **kwargs):
    """Instantiate a model, then optionally prune it, load pretrained weights and wrap it for features.

    Steps, in order:

    1. ``model_cls(**kwargs)`` (with ``cfg=model_cfg`` when given) is created and a deep copy of
       ``default_cfg`` is attached as ``model.default_cfg``.
    2. With ``pruned=True`` in kwargs, the model is replaced via ``adapt_model_from_file(model, variant)``.
    3. With ``pretrained``, weights are loaded by ``load_custom_pretrained`` or ``load_pretrained``.
    4. With ``features_only=True`` in kwargs, the model is wrapped in a feature class and gets a
       ``default_cfg`` without classifier fields.

    Args:
        model_cls: Callable producing the model.
        variant: Variant name, used to find the pruning spec file.
        pretrained: Whether to load pretrained weights.
        default_cfg: Default pretrained cfg for the model.
        model_cfg: Optional cfg passed to ``model_cls`` as ``cfg``.
        feature_cfg: Optional kwargs for the feature wrapper; may carry ``'feature_cls'``
            (a class, or a string containing ``'hook'``) and ``'out_indices'``. Not modified.
        pretrained_strict: ``strict`` value for ``load_pretrained``.
        pretrained_filter_fn: ``filter_fn`` value for ``load_pretrained``.
        pretrained_custom_load: Use ``load_custom_pretrained`` instead of ``load_pretrained``.
        **kwargs: ``pruned``, ``features_only`` and (only with ``features_only``) ``out_indices``
            are consumed here; everything else goes to ``model_cls``.

    Returns:
        The built model, or its feature wrapper when ``features_only`` was requested.

    Raises:
        AssertionError: If ``feature_cfg['feature_cls']`` is a string that does not contain ``'hook'``.
    """
    pruned = kwargs.pop('pruned', False)
    features = False
    # Work on a private copy: the setdefault()/pop() calls below must not leak into a dict
    # the caller may reuse for a later call.
    feature_cfg = dict(feature_cfg) if feature_cfg else {}

    if kwargs.pop('features_only', False):
        features = True
        feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
        if 'out_indices' in kwargs:
            feature_cfg['out_indices'] = kwargs.pop('out_indices')

    model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
    model.default_cfg = deepcopy(default_cfg)

    if pruned:
        model = adapt_model_from_file(model, variant)

    # for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
    num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
    if pretrained:
        if pretrained_custom_load:
            load_custom_pretrained(model)
        else:
            load_pretrained(
                model,
                num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),
                filter_fn=pretrained_filter_fn, strict=pretrained_strict)

    if features:
        feature_cls = FeatureListNet
        if 'feature_cls' in feature_cfg:
            feature_cls = feature_cfg.pop('feature_cls')
            if isinstance(feature_cls, str):
                feature_cls = feature_cls.lower()
                if 'hook' not in feature_cls:
                    # explicit raise (same exception type as before) so the check survives `python -O`
                    raise AssertionError(f'Unknown feature class {feature_cls}')
                feature_cls = FeatureHookNet
        model = feature_cls(model, **feature_cfg)
        model.default_cfg = default_cfg_for_features(default_cfg)  # add back default_cfg

    return model