""" Model creation / weight loading / state_dict helpers Hacked together by / Copyright 2020 Ross Wightman """ import logging import os import math from collections import OrderedDict from copy import deepcopy from typing import Callable import torch import torch.nn as nn from torch.hub import load_state_dict_from_url, download_url_to_file, urlparse, HASH_REGEX try: from torch.hub import get_dir except ImportError: from torch.hub import _get_torch_home as get_dir from .features import FeatureListNet, FeatureDictNet, FeatureHookNet from .layers import Conv2dSame, Linear _logger = logging.getLogger(__name__) def load_state_dict(checkpoint_path, use_ema=False): if checkpoint_path and os.path.isfile(checkpoint_path): checkpoint = torch.load(checkpoint_path, map_location='cpu') state_dict_key = 'state_dict' if isinstance(checkpoint, dict): if use_ema and 'state_dict_ema' in checkpoint: state_dict_key = 'state_dict_ema' if state_dict_key and state_dict_key in checkpoint: new_state_dict = OrderedDict() for k, v in checkpoint[state_dict_key].items(): # strip `module.` prefix name = k[7:] if k.startswith('module') else k new_state_dict[name] = v state_dict = new_state_dict else: state_dict = checkpoint _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path)) return state_dict else: _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) raise FileNotFoundError() def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True): state_dict = load_state_dict(checkpoint_path, use_ema) model.load_state_dict(state_dict, strict=strict) def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True): resume_epoch = None if os.path.isfile(checkpoint_path): checkpoint = torch.load(checkpoint_path, map_location='cpu') if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: if log_info: _logger.info('Restoring model state from checkpoint...') new_state_dict = OrderedDict() for k, v in checkpoint['state_dict'].items(): name = k[7:] if k.startswith('module') else k new_state_dict[name] = v model.load_state_dict(new_state_dict) if optimizer is not None and 'optimizer' in checkpoint: if log_info: _logger.info('Restoring optimizer state from checkpoint...') optimizer.load_state_dict(checkpoint['optimizer']) if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint: if log_info: _logger.info('Restoring AMP loss scaler state from checkpoint...') loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key]) if 'epoch' in checkpoint: resume_epoch = checkpoint['epoch'] if 'version' in checkpoint and checkpoint['version'] > 1: resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save if log_info: _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) else: model.load_state_dict(checkpoint) if log_info: _logger.info("Loaded checkpoint '{}'".format(checkpoint_path)) return resume_epoch else: _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) raise FileNotFoundError() def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_hash=False): r"""Loads a custom (read non .pth) weight file Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls a passed in custom load fun, or the `load_pretrained` model member fn. If the object is already present in `model_dir`, it's deserialized and returned. The default value of `model_dir` is ``/checkpoints`` where `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`. Args: model: The instantiated model to load weights into cfg (dict): Default pretrained model cfg load_fn: An external stand alone fn that loads weights into provided model, otherwise a fn named 'laod_pretrained' on the model will be called if it exists progress (bool, optional): whether or not to display a progress bar to stderr. Default: False check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention ``filename-.ext`` where ```` is the first eight or more digits of the SHA256 hash of the contents of the file. The hash is used to ensure unique names and to verify the contents of the file. Default: False """ cfg = cfg or getattr(model, 'default_cfg') if cfg is None or not cfg.get('url', None): _logger.warning("No pretrained weights exist for this model. Using random initialization.") return url = cfg['url'] # Issue warning to move data if old env is set if os.getenv('TORCH_MODEL_ZOO'): _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead') hub_dir = get_dir() model_dir = os.path.join(hub_dir, 'checkpoints') os.makedirs(model_dir, exist_ok=True) parts = urlparse(url) filename = os.path.basename(parts.path) cached_file = os.path.join(model_dir, filename) if not os.path.exists(cached_file): _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file)) hash_prefix = None if check_hash: r = HASH_REGEX.search(filename) # r is Optional[Match[str]] hash_prefix = r.group(1) if r else None download_url_to_file(url, cached_file, hash_prefix, progress=progress) if load_fn is not None: load_fn(model, cached_file) elif hasattr(model, 'load_pretrained'): model.load_pretrained(cached_file) else: _logger.warning("Valid function to load pretrained weights is not available, using random initialization.") def adapt_input_conv(in_chans, conv_weight): conv_type = conv_weight.dtype conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU O, I, J, K = conv_weight.shape if in_chans == 1: if I > 3: assert conv_weight.shape[1] % 3 == 0 # For models with space2depth stems conv_weight = conv_weight.reshape(O, I // 3, 3, J, K) conv_weight = conv_weight.sum(dim=2, keepdim=False) else: conv_weight = conv_weight.sum(dim=1, keepdim=True) elif in_chans != 3: if I != 3: raise NotImplementedError('Weight format not supported by conversion.') else: # NOTE this strategy should be better than random init, but there could be other combinations of # the original RGB input layer weights that'd work better for specific cases. repeat = int(math.ceil(in_chans / 3)) conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] conv_weight *= (3 / float(in_chans)) conv_weight = conv_weight.to(conv_type) return conv_weight def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False): cfg = cfg or getattr(model, 'default_cfg') if cfg is None or not cfg.get('url', None): _logger.warning("No pretrained weights exist for this model. Using random initialization.") return state_dict = load_state_dict_from_url(cfg['url'], progress=progress, map_location='cpu') if filter_fn is not None: state_dict = filter_fn(state_dict) input_convs = cfg.get('first_conv', None) if input_convs is not None and in_chans != 3: if isinstance(input_convs, str): input_convs = (input_convs,) for input_conv_name in input_convs: weight_name = input_conv_name + '.weight' try: state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name]) _logger.info( f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)') except NotImplementedError as e: del state_dict[weight_name] strict = False _logger.warning( f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.') classifier_name = cfg['classifier'] label_offset = cfg.get('label_offset', 0) if num_classes != cfg['num_classes']: # completely discard fully connected if model num_classes doesn't match pretrained weights del state_dict[classifier_name + '.weight'] del state_dict[classifier_name + '.bias'] strict = False elif label_offset > 0: # special case for pretrained weights with an extra background class in pretrained weights classifier_weight = state_dict[classifier_name + '.weight'] state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:] classifier_bias = state_dict[classifier_name + '.bias'] state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:] model.load_state_dict(state_dict, strict=strict) def extract_layer(model, layer): layer = layer.split('.') module = model if hasattr(model, 'module') and layer[0] != 'module': module = model.module if not hasattr(model, 'module') and layer[0] == 'module': layer = layer[1:] for l in layer: if hasattr(module, l): if not l.isdigit(): module = getattr(module, l) else: module = module[int(l)] else: return module return module def set_layer(model, layer, val): layer = layer.split('.') module = model if hasattr(model, 'module') and layer[0] != 'module': module = model.module lst_index = 0 module2 = module for l in layer: if hasattr(module2, l): if not l.isdigit(): module2 = getattr(module2, l) else: module2 = module2[int(l)] lst_index += 1 lst_index -= 1 for l in layer[:lst_index]: if not l.isdigit(): module = getattr(module, l) else: module = module[int(l)] l = layer[lst_index] setattr(module, l, val) def adapt_model_from_string(parent_module, model_string): separator = '***' state_dict = {} lst_shape = model_string.split(separator) for k in lst_shape: k = k.split(':') key = k[0] shape = k[1][1:-1].split(',') if shape[0] != '': state_dict[key] = [int(i) for i in shape] new_module = deepcopy(parent_module) for n, m in parent_module.named_modules(): old_module = extract_layer(parent_module, n) if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame): if isinstance(old_module, Conv2dSame): conv = Conv2dSame else: conv = nn.Conv2d s = state_dict[n + '.weight'] in_channels = s[1] out_channels = s[0] g = 1 if old_module.groups > 1: in_channels = out_channels g = in_channels new_conv = conv( in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size, bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation, groups=g, stride=old_module.stride) set_layer(new_module, n, new_conv) if isinstance(old_module, nn.BatchNorm2d): new_bn = nn.BatchNorm2d( num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum, affine=old_module.affine, track_running_stats=True) set_layer(new_module, n, new_bn) if isinstance(old_module, nn.Linear): # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer? num_features = state_dict[n + '.weight'][1] new_fc = Linear( in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None) set_layer(new_module, n, new_fc) if hasattr(new_module, 'num_features'): new_module.num_features = num_features new_module.eval() parent_module.eval() return new_module def adapt_model_from_file(parent_module, model_variant): adapt_file = os.path.join(os.path.dirname(__file__), 'pruned', model_variant + '.txt') with open(adapt_file, 'r') as f: return adapt_model_from_string(parent_module, f.read().strip()) def default_cfg_for_features(default_cfg): default_cfg = deepcopy(default_cfg) # remove default pretrained cfg fields that don't have much relevance for feature backbone to_remove = ('num_classes', 'crop_pct', 'classifier') # add default final pool size? for tr in to_remove: default_cfg.pop(tr, None) return default_cfg def build_model_with_cfg( model_cls: Callable, variant: str, pretrained: bool, default_cfg: dict, model_cfg: dict = None, feature_cfg: dict = None, pretrained_strict: bool = True, pretrained_filter_fn: Callable = None, pretrained_custom_load: bool = False, **kwargs): pruned = kwargs.pop('pruned', False) features = False feature_cfg = feature_cfg or {} if kwargs.pop('features_only', False): features = True feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4)) if 'out_indices' in kwargs: feature_cfg['out_indices'] = kwargs.pop('out_indices') model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs) model.default_cfg = deepcopy(default_cfg) if pruned: model = adapt_model_from_file(model, variant) # for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000)) if pretrained: if pretrained_custom_load: load_custom_pretrained(model) else: load_pretrained( model, num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3), filter_fn=pretrained_filter_fn, strict=pretrained_strict) if features: feature_cls = FeatureListNet if 'feature_cls' in feature_cfg: feature_cls = feature_cfg.pop('feature_cls') if isinstance(feature_cls, str): feature_cls = feature_cls.lower() if 'hook' in feature_cls: feature_cls = FeatureHookNet else: assert False, f'Unknown feature class {feature_cls}' model = feature_cls(model, **feature_cfg) model.default_cfg = default_cfg_for_features(default_cfg) # add back default_cfg return model def model_parameters(model, exclude_head=False): if exclude_head: # FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering return [p for p in model.parameters()][:-2] else: return model.parameters()