@@ -7,17 +7,14 @@ import os
 import math
 from collections import OrderedDict
 from copy import deepcopy
-from typing import Callable
+from typing import Any, Callable, Optional, Tuple

 import torch
 import torch.nn as nn
-from torch.hub import load_state_dict_from_url, download_url_to_file, urlparse, HASH_REGEX
-try:
-    from torch.hub import get_dir
-except ImportError:
-    from torch.hub import _get_torch_home as get_dir

 from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
+from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf, load_state_dict_from_url
 from .layers import Conv2dSame, Linear
@@ -92,7 +89,7 @@ def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None,
         raise FileNotFoundError()


-def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_hash=False):
+def load_custom_pretrained(model, default_cfg=None, load_fn=None, progress=False, check_hash=False):
     r"""Loads a custom (read non .pth) weight file

     Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls
@@ -104,7 +101,7 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_
     Args:
         model: The instantiated model to load weights into
-        cfg (dict): Default pretrained model cfg
+        default_cfg (dict): Default pretrained model cfg
         load_fn: An external stand alone fn that loads weights into provided model, otherwise a fn named
             'load_pretrained' on the model will be called if it exists
         progress (bool, optional): whether or not to display a progress bar to stderr. Default: False
@@ -113,31 +110,12 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_
         digits of the SHA256 hash of the contents of the file. The hash is used to
         ensure unique names and to verify the contents of the file. Default: False
     """
-    cfg = cfg or getattr(model, 'default_cfg')
-    if cfg is None or not cfg.get('url', None):
+    default_cfg = default_cfg or getattr(model, 'default_cfg', None) or {}
+    pretrained_url = default_cfg.get('url', None)
+    if not pretrained_url:
         _logger.warning("No pretrained weights exist for this model. Using random initialization.")
         return
-    url = cfg['url']
-
-    # Issue warning to move data if old env is set
-    if os.getenv('TORCH_MODEL_ZOO'):
-        _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
-
-    hub_dir = get_dir()
-    model_dir = os.path.join(hub_dir, 'checkpoints')
-    os.makedirs(model_dir, exist_ok=True)
-
-    parts = urlparse(url)
-    filename = os.path.basename(parts.path)
-    cached_file = os.path.join(model_dir, filename)
-    if not os.path.exists(cached_file):
-        _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
-        hash_prefix = None
-        if check_hash:
-            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
-            hash_prefix = r.group(1) if r else None
-        download_url_to_file(url, cached_file, hash_prefix, progress=progress)
+    cached_file = download_cached_file(default_cfg['url'], check_hash=check_hash, progress=progress)

     if load_fn is not None:
         load_fn(model, cached_file)
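
The refactor above replaces the hand-rolled torch.hub caching with download_cached_file. As a rough sketch of how a caller might pair load_custom_pretrained with a custom load_fn (the .npz handling and URL below are hypothetical, not part of this diff):

    import numpy as np
    import torch

    def load_npz_weights(model, checkpoint_path):
        # hypothetical loader for a non-.pth (.npz) checkpoint; key layout is illustrative
        weights = np.load(checkpoint_path)
        state_dict = {k: torch.from_numpy(v) for k, v in weights.items()}
        model.load_state_dict(state_dict, strict=False)

    # load_custom_pretrained(model, default_cfg={'url': 'https://example.com/w.npz'},
    #                        load_fn=load_npz_weights)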
@@ -172,17 +150,39 @@ def adapt_input_conv(in_chans, conv_weight):
     return conv_weight


-def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
-    cfg = cfg or getattr(model, 'default_cfg')
-    if cfg is None or not cfg.get('url', None):
+def load_pretrained(model, default_cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
+    """ Load pretrained checkpoint
+
+    Args:
+        model (nn.Module): PyTorch model module
+        default_cfg (Optional[Dict]): default configuration for pretrained weights / target dataset
+        num_classes (int): num_classes for model
+        in_chans (int): in_chans for model
+        filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args)
+        strict (bool): strict load of checkpoint
+        progress (bool): enable progress bar for weight download
+    """
+    default_cfg = default_cfg or getattr(model, 'default_cfg', None) or {}
+    pretrained_url = default_cfg.get('url', None)
+    hf_hub_id = default_cfg.get('hf_hub', None)
+    if not pretrained_url and not hf_hub_id:
         _logger.warning("No pretrained weights exist for this model. Using random initialization.")
         return
-    state_dict = load_state_dict_from_url(cfg['url'], progress=progress, map_location='cpu')
+    if hf_hub_id and has_hf_hub(necessary=not pretrained_url):
+        _logger.info(f'Loading pretrained weights from huggingface hub ({hf_hub_id})')
+        state_dict = load_state_dict_from_hf(hf_hub_id)
+    else:
+        _logger.info(f'Loading pretrained weights from url ({pretrained_url})')
+        state_dict = load_state_dict_from_url(pretrained_url, progress=progress, map_location='cpu')

     if filter_fn is not None:
-        state_dict = filter_fn(state_dict)
+        # for backwards compat with filter fns that take one arg, try one first, then two
+        try:
+            state_dict = filter_fn(state_dict)
+        except TypeError:
+            state_dict = filter_fn(state_dict, model)

-    input_convs = cfg.get('first_conv', None)
+    input_convs = default_cfg.get('first_conv', None)
     if input_convs is not None and in_chans != 3:
         if isinstance(input_convs, str):
             input_convs = (input_convs,)
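
The try/except added above keeps legacy single-argument filter fns working while allowing the new two-argument form. A minimal sketch of both (names and key handling are illustrative, not part of this diff):

    def filter_fn_one_arg(state_dict):
        # legacy form: rewrites checkpoint keys only
        return {k.replace('module.', ''): v for k, v in state_dict.items()}

    def filter_fn_two_arg(state_dict, model):
        # newer form: may consult the target model, e.g. keep only keys it has
        return {k: v for k, v in state_dict.items() if k in model.state_dict()}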
@@ -198,19 +198,20 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
                 _logger.warning(
                     f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.')

-    classifier_name = cfg['classifier']
-    label_offset = cfg.get('label_offset', 0)
-    if num_classes != cfg['num_classes']:
-        # completely discard fully connected if model num_classes doesn't match pretrained weights
-        del state_dict[classifier_name + '.weight']
-        del state_dict[classifier_name + '.bias']
-        strict = False
-    elif label_offset > 0:
-        # special case for pretrained weights with an extra background class in pretrained weights
-        classifier_weight = state_dict[classifier_name + '.weight']
-        state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
-        classifier_bias = state_dict[classifier_name + '.bias']
-        state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
+    classifier_name = default_cfg.get('classifier', None)
+    label_offset = default_cfg.get('label_offset', 0)
+    if classifier_name is not None:
+        if num_classes != default_cfg['num_classes']:
+            # completely discard fully connected if model num_classes doesn't match pretrained weights
+            del state_dict[classifier_name + '.weight']
+            del state_dict[classifier_name + '.bias']
+            strict = False
+        elif label_offset > 0:
+            # special case for pretrained weights with an extra background class in pretrained weights
+            classifier_weight = state_dict[classifier_name + '.weight']
+            state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
+            classifier_bias = state_dict[classifier_name + '.bias']
+            state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]

     model.load_state_dict(state_dict, strict=strict)
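
Two head-adaptation cases are handled in this hunk: a num_classes mismatch discards the classifier weights and disables strict loading, while label_offset slices leading classes off the pretrained head (e.g. an extra background class). Roughly, with illustrative tensor shapes:

    import torch

    state_dict = {'fc.weight': torch.randn(1001, 2048), 'fc.bias': torch.randn(1001)}
    label_offset = 1  # pretrained head carries one extra leading background class
    state_dict['fc.weight'] = state_dict['fc.weight'][label_offset:]  # -> (1000, 2048)
    state_dict['fc.bias'] = state_dict['fc.bias'][label_offset:]      # -> (1000,)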
@@ -316,40 +317,116 @@ def adapt_model_from_file(parent_module, model_variant):
 def default_cfg_for_features(default_cfg):
     default_cfg = deepcopy(default_cfg)
     # remove default pretrained cfg fields that don't have much relevance for feature backbone
-    to_remove = ('num_classes', 'crop_pct', 'classifier')  # add default final pool size?
+    to_remove = ('num_classes', 'crop_pct', 'classifier', 'global_pool')  # add default final pool size?
     for tr in to_remove:
         default_cfg.pop(tr, None)
     return default_cfg

+def overlay_external_default_cfg(kwargs, default_cfg):
+    """ Overlay 'external_default_cfg' in kwargs on top of default_cfg arg.
+    """
+    default_cfg = default_cfg or {}
+    external_default_cfg = kwargs.pop('external_default_cfg', None)
+    if external_default_cfg:
+        default_cfg = deepcopy(default_cfg)
+        default_cfg.pop('url', None)  # url should come from external cfg
+        default_cfg.pop('hf_hub', None)  # hf hub id should come from external cfg
+        default_cfg.update(external_default_cfg)
+    return default_cfg
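
A small sketch of the overlay behavior with made-up cfg values; 'url' and 'hf_hub' are popped from the base cfg so the weight source always comes from the external cfg:

    default_cfg = {'url': 'https://example.com/a.pth', 'num_classes': 1000}
    kwargs = {'external_default_cfg': {'hf_hub': 'myorg/mymodel'}, 'drop_rate': 0.1}
    default_cfg = overlay_external_default_cfg(kwargs, default_cfg)
    # default_cfg == {'num_classes': 1000, 'hf_hub': 'myorg/mymodel'}
    # kwargs == {'drop_rate': 0.1}  ('external_default_cfg' was popped)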
+
+
+def set_default_kwargs(kwargs, names, default_cfg):
+    for n in names:
+        # for legacy reasons, model __init__ args use img_size + in_chans as separate args while
+        # default_cfg has one input_size=(C, H, W) entry
+        if n == 'img_size':
+            input_size = default_cfg.get('input_size', None)
+            if input_size is not None:
+                assert len(input_size) == 3
+                kwargs.setdefault(n, input_size[-2:])
+        elif n == 'in_chans':
+            input_size = default_cfg.get('input_size', None)
+            if input_size is not None:
+                assert len(input_size) == 3
+                kwargs.setdefault(n, input_size[0])
+        else:
+            default_val = default_cfg.get(n, None)
+            if default_val is not None:
+                kwargs.setdefault(n, default_cfg[n])
+
+
+def filter_kwargs(kwargs, names):
+    if not kwargs or not names:
+        return
+    for n in names:
+        kwargs.pop(n, None)
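
How the two helpers above compose, with illustrative values; setdefault means explicitly passed kwargs always win over cfg-derived defaults:

    kwargs = {'num_classes': 10}
    default_cfg = {'num_classes': 1000, 'global_pool': 'avg', 'input_size': (3, 224, 224)}
    set_default_kwargs(kwargs, names=('num_classes', 'global_pool', 'in_chans'), default_cfg=default_cfg)
    # kwargs == {'num_classes': 10, 'global_pool': 'avg', 'in_chans': 3}
    filter_kwargs(kwargs, names=('global_pool',))
    # kwargs == {'num_classes': 10, 'in_chans': 3}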
+
+
 def build_model_with_cfg(
         model_cls: Callable,
         variant: str,
         pretrained: bool,
         default_cfg: dict,
-        model_cfg: dict = None,
-        feature_cfg: dict = None,
+        model_cfg: Optional[Any] = None,
+        feature_cfg: Optional[dict] = None,
         pretrained_strict: bool = True,
-        pretrained_filter_fn: Callable = None,
+        pretrained_filter_fn: Optional[Callable] = None,
         pretrained_custom_load: bool = False,
+        kwargs_filter: Optional[Tuple[str]] = None,
         **kwargs):
+    """ Build model with specified default_cfg and optional model_cfg
+
+    This helper fn aids in the construction of a model including:
+      * handling default_cfg and associated pretrained weight loading
+      * passing through optional model_cfg for models with config based arch spec
+      * features_only model adaptation
+      * pruning config / model adaptation
+
+    Args:
+        model_cls (nn.Module): model class
+        variant (str): model variant name
+        pretrained (bool): load pretrained weights
+        default_cfg (dict): model's default pretrained/task config
+        model_cfg (Optional[Dict]): model's architecture config
+        feature_cfg (Optional[Dict]): feature extraction adapter config
+        pretrained_strict (bool): load pretrained weights strictly
+        pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights
+        pretrained_custom_load (bool): use custom load fn, to load numpy or other non PyTorch weights
+        kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model
+        **kwargs: model args passed through to model __init__
+    """
     pruned = kwargs.pop('pruned', False)
     features = False
     feature_cfg = feature_cfg or {}

+    # Setup for feature extraction wrapper done at end of this fn
     if kwargs.pop('features_only', False):
         features = True
         feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
         if 'out_indices' in kwargs:
             feature_cfg['out_indices'] = kwargs.pop('out_indices')

+    # FIXME this next sequence of overlay default_cfg, set default kwargs, filter kwargs
+    # could/should be replaced by an improved configuration mechanism
+
+    # Overlay default cfg values from `external_default_cfg` if it exists in kwargs
+    default_cfg = overlay_external_default_cfg(kwargs, default_cfg)
+
+    # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
+    set_default_kwargs(kwargs, names=('num_classes', 'global_pool', 'in_chans'), default_cfg=default_cfg)
+
+    # Filter keyword args for task specific model variants (some 'features only' models, etc.)
+    filter_kwargs(kwargs, names=kwargs_filter)
+
+    # Build the model
     model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
     model.default_cfg = deepcopy(default_cfg)

     if pruned:
         model = adapt_model_from_file(model, variant)

-    # for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
+    # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
     num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
     if pretrained:
         if pretrained_custom_load:
@@ -357,9 +434,12 @@ def build_model_with_cfg(
         else:
             load_pretrained(
                 model,
-                num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),
-                filter_fn=pretrained_filter_fn, strict=pretrained_strict)
+                num_classes=num_classes_pretrained,
+                in_chans=kwargs.get('in_chans', 3),
+                filter_fn=pretrained_filter_fn,
+                strict=pretrained_strict)

+    # Wrap the model in a feature extraction module if enabled
     if features:
         feature_cls = FeatureListNet
         if 'feature_cls' in feature_cfg:
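
For context, a typical model entrypoint built on build_model_with_cfg, in the style of timm model files (MyNet and my_default_cfgs are placeholders, not part of this diff):

    def _create_mynet(variant, pretrained=False, **kwargs):
        # MyNet (an nn.Module subclass) and my_default_cfgs (per-variant cfg dict)
        # are assumed to exist in the surrounding model file
        return build_model_with_cfg(
            MyNet, variant, pretrained,
            default_cfg=my_default_cfgs[variant],
            **kwargs)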