@@ -7,17 +7,14 @@ import os
import math
from collections import OrderedDict
from copy import deepcopy
from typing import Callable
from typing import Any, Callable, Optional, Tuple
import torch
import torch.nn as nn
from torch.hub import load_state_dict_from_url, download_url_to_file, urlparse, HASH_REGEX
try:
    from torch.hub import get_dir
except ImportError:
    from torch.hub import _get_torch_home as get_dir
from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf, load_state_dict_from_url
from .layers import Conv2dSame, Linear
@@ -92,7 +89,7 @@ def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None,
        raise FileNotFoundError()
def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_hash=False):
def load_custom_pretrained(model, default_cfg=None, load_fn=None, progress=False, check_hash=False):
    r"""Loads a custom (read non .pth) weight file
    Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls
@@ -104,7 +101,7 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_
    Args:
        model: The instantiated model to load weights into
        cfg (dict): Default pretrained model cfg
        default_cfg (dict): Default pretrained model cfg
        load_fn: An external stand-alone fn that loads weights into provided model, otherwise a fn named
            'load_pretrained' on the model will be called if it exists
        progress (bool, optional): whether or not to display a progress bar to stderr. Default: False
@@ -113,31 +110,12 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_
            digits of the SHA256 hash of the contents of the file. The hash is used to
            ensure unique names and to verify the contents of the file. Default: False
    """
    cfg = cfg or getattr(model, 'default_cfg')
    if cfg is None or not cfg.get('url', None):
    default_cfg = default_cfg or getattr(model, 'default_cfg', None) or {}
    pretrained_url = default_cfg.get('url', None)
    if not pretrained_url:
        _logger.warning("No pretrained weights exist for this model. Using random initialization.")
        return
    url = cfg['url']
    # Issue warning to move data if old env is set
    if os.getenv('TORCH_MODEL_ZOO'):
        _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
    hub_dir = get_dir()
    model_dir = os.path.join(hub_dir, 'checkpoints')
    os.makedirs(model_dir, exist_ok=True)
    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        if check_hash:
            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
            hash_prefix = r.group(1) if r else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    cached_file = download_cached_file(default_cfg['url'], check_hash=check_hash, progress=progress)
    if load_fn is not None:
        load_fn(model, cached_file)
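
# Hypothetical usage sketch for load_custom_pretrained(); the model instance, URL and
# npz helper below are made-up placeholders, not part of the patch above.
import numpy as np

def _load_npz_weights(model, checkpoint_path):
    # stand-alone load_fn: receives (model, cached_file) as called above
    model.load_npz(np.load(checkpoint_path))  # assumes the model defines a load_npz() helper

load_custom_pretrained(
    my_model,  # placeholder model instance
    default_cfg={'url': 'https://example.com/checkpoints/weights.npz'},
    load_fn=_load_npz_weights,
    check_hash=False)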
@@ -172,17 +150,39 @@ def adapt_input_conv(in_chans, conv_weight):
    return conv_weight
def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
    cfg = cfg or getattr(model, 'default_cfg')
    if cfg is None or not cfg.get('url', None):
def load_pretrained(model, default_cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
    """ Load pretrained checkpoint
    Args:
        model (nn.Module): PyTorch model module
        default_cfg (Optional[Dict]): default configuration for pretrained weights / target dataset
        num_classes (int): num_classes for model
        in_chans (int): in_chans for model
        filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args)
        strict (bool): strict load of checkpoint
        progress (bool): enable progress bar for weight download
    """
    default_cfg = default_cfg or getattr(model, 'default_cfg', None) or {}
    pretrained_url = default_cfg.get('url', None)
    hf_hub_id = default_cfg.get('hf_hub', None)
    if not pretrained_url and not hf_hub_id:
        _logger.warning("No pretrained weights exist for this model. Using random initialization.")
        return
    state_dict = load_state_dict_from_url(cfg['url'], progress=progress, map_location='cpu')
    if hf_hub_id and has_hf_hub(necessary=not pretrained_url):
        _logger.info(f'Loading pretrained weights from huggingface hub ({hf_hub_id})')
        state_dict = load_state_dict_from_hf(hf_hub_id)
    else:
        _logger.info(f'Loading pretrained weights from url ({pretrained_url})')
        state_dict = load_state_dict_from_url(pretrained_url, progress=progress, map_location='cpu')
    if filter_fn is not None:
        state_dict = filter_fn(state_dict)
        # for backwards compat with filter fns that take one arg, try one first, then two
        try:
            state_dict = filter_fn(state_dict)
        except TypeError:
            state_dict = filter_fn(state_dict, model)
    input_convs = cfg.get('first_conv', None)
    input_convs = default_cfg.get('first_conv', None)
    if input_convs is not None and in_chans != 3:
        if isinstance(input_convs, str):
            input_convs = (input_convs,)
@@ -198,19 +198,20 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
                _logger.warning(
                    f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.')
    classifier_name = cfg['classifier']
    label_offset = cfg.get('label_offset', 0)
    if num_classes != cfg['num_classes']:
        # completely discard fully connected if model num_classes doesn't match pretrained weights
        del state_dict[classifier_name + '.weight']
        del state_dict[classifier_name + '.bias']
        strict = False
    elif label_offset > 0:
        # special case for pretrained weights with an extra background class
        classifier_weight = state_dict[classifier_name + '.weight']
        state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
        classifier_bias = state_dict[classifier_name + '.bias']
        state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
    classifier_name = default_cfg.get('classifier', None)
    label_offset = default_cfg.get('label_offset', 0)
    if classifier_name is not None:
        if num_classes != default_cfg['num_classes']:
            # completely discard fully connected if model num_classes doesn't match pretrained weights
            del state_dict[classifier_name + '.weight']
            del state_dict[classifier_name + '.bias']
            strict = False
        elif label_offset > 0:
            # special case for pretrained weights with an extra background class
            classifier_weight = state_dict[classifier_name + '.weight']
            state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
            classifier_bias = state_dict[classifier_name + '.bias']
            state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
    model.load_state_dict(state_dict, strict=strict)
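
# Hypothetical sketch of a two-argument filter_fn for the new load_pretrained() signature;
# the model instance and key filtering below are illustrative assumptions, not timm code.
def _drop_unmatched_keys(state_dict, model):
    # keep only keys the current architecture defines; one-argument filter fns keep
    # working via the TypeError fallback above
    model_keys = set(model.state_dict())
    return {k: v for k, v in state_dict.items() if k in model_keys}

load_pretrained(
    my_model,                        # placeholder model instance
    default_cfg=my_model.default_cfg,
    num_classes=10,                  # mismatch vs. pretrained head, so classifier weights are dropped
    in_chans=1,                      # first conv adapted via adapt_input_conv
    filter_fn=_drop_unmatched_keys)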
@@ -316,40 +317,116 @@ def adapt_model_from_file(parent_module, model_variant):
def default_cfg_for_features(default_cfg):
    default_cfg = deepcopy(default_cfg)
    # remove default pretrained cfg fields that don't have much relevance for feature backbone
    to_remove = ('num_classes', 'crop_pct', 'classifier')  # add default final pool size?
    to_remove = ('num_classes', 'crop_pct', 'classifier', 'global_pool')  # add default final pool size?
    for tr in to_remove:
        default_cfg.pop(tr, None)
    return default_cfg
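
# Illustrative example of what the stripping above produces (all values are made-up):
#   default_cfg_for_features({'url': 'https://example.com/w.pth', 'num_classes': 1000,
#                             'input_size': (3, 224, 224), 'crop_pct': 0.875,
#                             'classifier': 'fc', 'global_pool': 'avg'})
#   -> {'url': 'https://example.com/w.pth', 'input_size': (3, 224, 224)}
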
def overlay_external_default_cfg(kwargs, default_cfg):
    """ Overlay 'external_default_cfg' in kwargs on top of default_cfg arg.
    """
    default_cfg = default_cfg or {}
    external_default_cfg = kwargs.pop('external_default_cfg', None)
    if external_default_cfg:
        default_cfg = deepcopy(default_cfg)
        default_cfg.pop('url', None)  # url should come from external cfg
        default_cfg.pop('hf_hub', None)  # hf hub id should come from external cfg
        default_cfg.update(external_default_cfg)
    return default_cfg
def set_default_kwargs(kwargs, names, default_cfg):
    for n in names:
        # for legacy reasons, model __init__ args use img_size + in_chans as separate args while
        # default_cfg has one input_size=(C, H, W) entry
        if n == 'img_size':
            input_size = default_cfg.get('input_size', None)
            if input_size is not None:
                assert len(input_size) == 3
                kwargs.setdefault(n, input_size[-2:])
        elif n == 'in_chans':
            input_size = default_cfg.get('input_size', None)
            if input_size is not None:
                assert len(input_size) == 3
                kwargs.setdefault(n, input_size[0])
        else:
            default_val = default_cfg.get(n, None)
            if default_val is not None:
                kwargs.setdefault(n, default_cfg[n])
def filter_kwargs(kwargs, names):
    if not kwargs or not names:
        return
    for n in names:
        kwargs.pop(n, None)
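
# Illustrative sketch (made-up values) of how the three helpers above cooperate inside
# build_model_with_cfg before the model is constructed.
kwargs = {'external_default_cfg': {'url': '', 'hf_hub': 'my-org/my-model', 'num_classes': 10},
          'img_size': 384}
default_cfg = {'url': 'https://example.com/w.pth', 'num_classes': 1000,
               'input_size': (3, 224, 224), 'pool_size': (7, 7)}

default_cfg = overlay_external_default_cfg(kwargs, default_cfg)
# url / hf_hub now come only from the external cfg and num_classes becomes 10

set_default_kwargs(kwargs, names=('num_classes', 'img_size', 'in_chans'), default_cfg=default_cfg)
# kwargs gains num_classes=10 and in_chans=3; the explicitly passed img_size=384 is kept

filter_kwargs(kwargs, names=('in_chans',))
# in_chans is dropped again, e.g. for a model __init__ that has no such argument
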
def build_model_with_cfg(
        model_cls: Callable,
        variant: str,
        pretrained: bool,
        default_cfg: dict,
        model_cfg: dict = None,
        feature_cfg: dict = None,
        model_cfg: Optional[Any] = None,
        feature_cfg: Optional[dict] = None,
        pretrained_strict: bool = True,
        pretrained_filter_fn: Callable = None,
        pretrained_filter_fn: Optional[Callable] = None,
        pretrained_custom_load: bool = False,
        kwargs_filter: Optional[Tuple[str]] = None,
        **kwargs):
""" Build model with specified default_cfg and optional model_cfg
This helper fn aids in the construction of a model including :
* handling default_cfg and associated pretained weight loading
* passing through optional model_cfg for models with config based arch spec
* features_only model adaptation
* pruning config / model adaptation
Args :
model_cls ( nn . Module ) : model class
variant ( str ) : model variant name
pretrained ( bool ) : load pretrained weights
default_cfg ( dict ) : model ' s default pretrained/task config
model_cfg ( Optional [ Dict ] ) : model ' s architecture config
feature_cfg ( Optional [ Dict ] : feature extraction adapter config
pretrained_strict ( bool ) : load pretrained weights strictly
pretrained_filter_fn ( Optional [ Callable ] ) : filter callable for pretrained weights
pretrained_custom_load ( bool ) : use custom load fn , to load numpy or other non PyTorch weights
kwargs_filter ( Optional [ Tuple ] ) : kwargs to filter before passing to model
* * kwargs : model args passed through to model __init__
"""
    pruned = kwargs.pop('pruned', False)
    features = False
    feature_cfg = feature_cfg or {}
    # Setup for feature extraction wrapper done at end of this fn
    if kwargs.pop('features_only', False):
        features = True
        feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
        if 'out_indices' in kwargs:
            feature_cfg['out_indices'] = kwargs.pop('out_indices')
    # FIXME this next sequence of overlay default_cfg, set default kwargs, filter kwargs
    # could/should be replaced by an improved configuration mechanism
    # Overlay default cfg values from `external_default_cfg` if it exists in kwargs
    default_cfg = overlay_external_default_cfg(kwargs, default_cfg)
    # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
    set_default_kwargs(kwargs, names=('num_classes', 'global_pool', 'in_chans'), default_cfg=default_cfg)
    # Filter keyword args for task specific model variants (some 'features only' models, etc.)
    filter_kwargs(kwargs, names=kwargs_filter)
    # Build the model
    model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
    model.default_cfg = deepcopy(default_cfg)
    if pruned:
        model = adapt_model_from_file(model, variant)
    # for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
    # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
    num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
    if pretrained:
        if pretrained_custom_load:
@@ -357,9 +434,12 @@ def build_model_with_cfg(
        else:
            load_pretrained(
                model,
                num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),
                filter_fn=pretrained_filter_fn, strict=pretrained_strict)
                num_classes=num_classes_pretrained,
                in_chans=kwargs.get('in_chans', 3),
                filter_fn=pretrained_filter_fn,
                strict=pretrained_strict)
    # Wrap the model in a feature extraction module if enabled
    if features:
        feature_cls = FeatureListNet
        if 'feature_cls' in feature_cfg:
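
# Hypothetical entrypoint sketch showing how a registered model variant would typically call
# build_model_with_cfg; MyNet, the cfg values and the URL are placeholders, not timm code.
def mynet_small(pretrained=False, **kwargs):
    default_cfg = {
        'url': 'https://example.com/mynet_small.pth', 'num_classes': 1000,
        'input_size': (3, 224, 224), 'crop_pct': 0.875,
        'first_conv': 'stem.conv', 'classifier': 'head.fc'}
    return build_model_with_cfg(
        MyNet, 'mynet_small', pretrained,
        default_cfg=default_cfg,
        **kwargs)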