@@ -3,11 +3,12 @@
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import logging
+import json
 import os
 import math
 from collections import OrderedDict
 from copy import deepcopy
-from typing import Callable
+from typing import Callable, Optional, Union

 import torch
 import torch.nn as nn
@@ -17,13 +18,31 @@ try:
 except ImportError:
     from torch.hub import _get_torch_home as get_dir

+from huggingface_hub import cached_download, hf_hub_url
+
 from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
 from .layers import Conv2dSame, Linear
+from ..version import __version__


 _logger = logging.getLogger(__name__)


+def get_cache_dir():
+    """
+    Returns the location of the directory where models are cached (and creates it if necessary).
+    """
+    # Issue warning to move data if old env is set
+    if os.getenv('TORCH_MODEL_ZOO'):
+        _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
+
+    hub_dir = get_dir()
+    model_dir = os.path.join(hub_dir, 'checkpoints')
+
+    os.makedirs(model_dir, exist_ok=True)
+    return model_dir
+
+
 def load_state_dict(checkpoint_path, use_ema=False):
     if checkpoint_path and os.path.isfile(checkpoint_path):
         checkpoint = torch.load(checkpoint_path, map_location='cpu')
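For orientation, the new `get_cache_dir()` just resolves torch hub's cache root and ensures a `checkpoints` subdirectory exists, so both URL and Hub downloads land in one place. A minimal sketch of what a caller sees, assuming the helpers live at `timm.models.helpers` as in this diff (the exact path depends on `TORCH_HOME`/`XDG_CACHE_HOME`):

```python
import os

from timm.models.helpers import get_cache_dir  # module path assumed from this diff

cache = get_cache_dir()
assert os.path.isdir(cache)  # the directory is created on first call
print(cache)                 # typically ~/.cache/torch/hub/checkpoints
```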
@@ -120,15 +139,7 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_
         return
     url = cfg['url']

-    # Issue warning to move data if old env is set
-    if os.getenv('TORCH_MODEL_ZOO'):
-        _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
-
-    hub_dir = get_dir()
-    model_dir = os.path.join(hub_dir, 'checkpoints')
-
-    os.makedirs(model_dir, exist_ok=True)
-
+    model_dir = get_cache_dir()
     parts = urlparse(url)
     filename = os.path.basename(parts.path)
     cached_file = os.path.join(model_dir, filename)
@@ -173,14 +184,27 @@ def adapt_input_conv(in_chans, conv_weight):
     return conv_weight


-def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
+def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True,
+                    progress=False, hf_checkpoint=None, hf_revision=None):
     if cfg is None:
         cfg = getattr(model, 'default_cfg')
-    if cfg is None or 'url' not in cfg or not cfg['url']:
+    if hf_checkpoint is None:
+        hf_checkpoint = cfg.get('hf_checkpoint')
+    if hf_revision is None:
+        hf_revision = cfg.get('hf_revision')
+    if cfg is None or (('url' not in cfg or not cfg['url']) and hf_checkpoint is None):
         _logger.warning("No pretrained weights exist for this model. Using random initialization.")
         return

-    state_dict = load_state_dict_from_url(cfg['url'], progress=progress, map_location='cpu')
+    if hf_checkpoint is not None:
+        # TODO: progress is currently ignored for Hub downloads.
+        url = hf_hub_url(hf_checkpoint, "pytorch_model.pth", revision=hf_revision)
+        cached_file = cached_download(
+            url, library_name="timm", library_version=__version__, cache_dir=get_cache_dir()
+        )
+        state_dict = torch.load(cached_file, map_location='cpu')
+    else:
+        state_dict = load_state_dict_from_url(cfg['url'], progress=progress, map_location='cpu')
     if filter_fn is not None:
         state_dict = filter_fn(state_dict)

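With this change, `load_pretrained` can source weights either from the `url` in a model's `default_cfg` or directly from a Hugging Face Hub repo. A minimal sketch of the new path, assuming the function is imported from `timm.models.helpers`; the repo id is hypothetical and would need to host a matching `pytorch_model.pth`:

```python
import timm
from timm.models.helpers import load_pretrained  # module path assumed from this diff

model = timm.create_model('resnet18', pretrained=False)  # random init
# 'someone/resnet18-demo' is a hypothetical Hub repo containing a state dict
# named pytorch_model.pth that matches this architecture
load_pretrained(model, hf_checkpoint='someone/resnet18-demo', hf_revision='main')
```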
@@ -336,6 +360,8 @@ def build_model_with_cfg(
         pretrained_custom_load: bool = False,
         **kwargs):
     pruned = kwargs.pop('pruned', False)
+    hf_checkpoint = kwargs.pop('hf_checkpoint', None)
+    hf_revision = kwargs.pop('hf_revision', None)
     features = False
     feature_cfg = feature_cfg or {}

@@ -353,14 +379,15 @@ def build_model_with_cfg(

     # for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
     num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
-    if pretrained:
+    if pretrained or hf_checkpoint is not None:
         if pretrained_custom_load:
             load_custom_pretrained(model)
         else:
             load_pretrained(
                 model,
                 num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),
-                filter_fn=pretrained_filter_fn, strict=pretrained_strict)
+                filter_fn=pretrained_filter_fn, strict=pretrained_strict,
+                hf_checkpoint=hf_checkpoint, hf_revision=hf_revision)

     if features:
         feature_cls = FeatureListNet
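Because `build_model_with_cfg` pops `hf_checkpoint`/`hf_revision` out of `**kwargs`, the new options reach every model factory without touching individual entrypoints. A minimal sketch, assuming `timm.create_model` forwards extra kwargs as usual (the repo id is again hypothetical):

```python
import timm

# hf_checkpoint/hf_revision travel via **kwargs into build_model_with_cfg,
# which now triggers load_pretrained even though pretrained=False
model = timm.create_model(
    'resnet18',
    pretrained=False,
    hf_checkpoint='someone/resnet18-demo',  # hypothetical Hub repo
    hf_revision='main',
)
```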
@@ -376,3 +403,17 @@ def build_model_with_cfg(
         model.default_cfg = default_cfg_for_features(default_cfg)  # add back default_cfg

     return model
+
+
+def load_cfg_from_json(json_file: Union[str, os.PathLike]):
+    with open(json_file, "r", encoding="utf-8") as reader:
+        text = reader.read()
+    return json.loads(text)
+
+
+def load_hf_checkpoint_config(checkpoint: str, revision: Optional[str] = None):
+    url = hf_hub_url(checkpoint, "config.json", revision=revision)
+    cached_file = cached_download(
+        url, library_name="timm", library_version=__version__, cache_dir=get_cache_dir()
+    )
+    return load_cfg_from_json(cached_file)
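Finally, `load_hf_checkpoint_config` fetches a `config.json` stored alongside the weights and parses it into a dict, which a caller could use as a `default_cfg`. A minimal sketch, again with an assumed module path and a hypothetical repo id that must actually contain `config.json`:

```python
from timm.models.helpers import load_hf_checkpoint_config  # module path assumed

# downloads config.json into get_cache_dir() and parses it into a dict
cfg = load_hf_checkpoint_config('someone/resnet18-demo', revision='main')
print(cfg)
```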