|
|
@ -17,13 +17,31 @@ try:
|
|
|
|
except ImportError:
|
|
|
|
except ImportError:
|
|
|
|
from torch.hub import _get_torch_home as get_dir
|
|
|
|
from torch.hub import _get_torch_home as get_dir
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from huggingface_hub import cached_download
|
|
|
|
|
|
|
|
|
|
|
|
from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
|
|
|
|
from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
|
|
|
|
from .layers import Conv2dSame, Linear
|
|
|
|
from .layers import Conv2dSame, Linear
|
|
|
|
|
|
|
|
from ..version import __version__
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_logger = logging.getLogger(__name__)
|
|
|
|
_logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_cache_dir():
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
Returns the location of the directory where models are cached (and creates it if necessary).
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
# Issue warning to move data if old env is set
|
|
|
|
|
|
|
|
if os.getenv('TORCH_MODEL_ZOO'):
|
|
|
|
|
|
|
|
_logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hub_dir = get_dir()
|
|
|
|
|
|
|
|
model_dir = os.path.join(hub_dir, 'checkpoints')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
os.makedirs(model_dir, exist_ok=True)
|
|
|
|
|
|
|
|
return model_dir
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_state_dict(checkpoint_path, use_ema=False):
|
|
|
|
def load_state_dict(checkpoint_path, use_ema=False):
|
|
|
|
if checkpoint_path and os.path.isfile(checkpoint_path):
|
|
|
|
if checkpoint_path and os.path.isfile(checkpoint_path):
|
|
|
|
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
|
|
|
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
|
|
@ -120,25 +138,10 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_
|
|
|
|
return
|
|
|
|
return
|
|
|
|
url = cfg['url']
|
|
|
|
url = cfg['url']
|
|
|
|
|
|
|
|
|
|
|
|
# Issue warning to move data if old env is set
|
|
|
|
# TODO, progress and check_hash are ignored.
|
|
|
|
if os.getenv('TORCH_MODEL_ZOO'):
|
|
|
|
cached_filed = cached_download(
|
|
|
|
_logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
|
|
|
|
url, library_name="timm", library_version=__version__, cache_dir=get_cache_dir()
|
|
|
|
|
|
|
|
)
|
|
|
|
hub_dir = get_dir()
|
|
|
|
|
|
|
|
model_dir = os.path.join(hub_dir, 'checkpoints')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
os.makedirs(model_dir, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
parts = urlparse(url)
|
|
|
|
|
|
|
|
filename = os.path.basename(parts.path)
|
|
|
|
|
|
|
|
cached_file = os.path.join(model_dir, filename)
|
|
|
|
|
|
|
|
if not os.path.exists(cached_file):
|
|
|
|
|
|
|
|
_logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
|
|
|
|
|
|
|
|
hash_prefix = None
|
|
|
|
|
|
|
|
if check_hash:
|
|
|
|
|
|
|
|
r = HASH_REGEX.search(filename) # r is Optional[Match[str]]
|
|
|
|
|
|
|
|
hash_prefix = r.group(1) if r else None
|
|
|
|
|
|
|
|
download_url_to_file(url, cached_file, hash_prefix, progress=progress)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if load_fn is not None:
|
|
|
|
if load_fn is not None:
|
|
|
|
load_fn(model, cached_file)
|
|
|
|
load_fn(model, cached_file)
|
|
|
@ -180,7 +183,11 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
|
|
|
|
_logger.warning("No pretrained weights exist for this model. Using random initialization.")
|
|
|
|
_logger.warning("No pretrained weights exist for this model. Using random initialization.")
|
|
|
|
return
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
state_dict = load_state_dict_from_url(cfg['url'], progress=progress, map_location='cpu')
|
|
|
|
# TODO, progress is ignored.
|
|
|
|
|
|
|
|
cached_filed = cached_download(
|
|
|
|
|
|
|
|
cfg['url'], library_name="timm", library_version=__version__, cache_dir=get_cache_dir()
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
state_dict = torch.load(cached_filed, map_location='cpu')
|
|
|
|
if filter_fn is not None:
|
|
|
|
if filter_fn is not None:
|
|
|
|
state_dict = filter_fn(state_dict)
|
|
|
|
state_dict = filter_fn(state_dict)
|
|
|
|
|
|
|
|
|
|
|
|