|
|
|
@ -3,11 +3,12 @@
|
|
|
|
|
Hacked together by / Copyright 2020 Ross Wightman
|
|
|
|
|
"""
|
|
|
|
|
import logging
|
|
|
|
|
import json
|
|
|
|
|
import os
|
|
|
|
|
import math
|
|
|
|
|
from collections import OrderedDict
|
|
|
|
|
from copy import deepcopy
|
|
|
|
|
from typing import Callable
|
|
|
|
|
from typing import Callable, Optional, Union
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
@ -183,16 +184,21 @@ def adapt_input_conv(in_chans, conv_weight):
|
|
|
|
|
return conv_weight
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
|
|
|
|
|
def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True,
|
|
|
|
|
progress=False, hf_checkpoint=None, hf_revision=None):
|
|
|
|
|
if cfg is None:
|
|
|
|
|
cfg = getattr(model, 'default_cfg')
|
|
|
|
|
if cfg is None or 'url' not in cfg or not cfg['url'] or 'hf_checkpoint' not in cfg or not cfg['hf_checkpoint']:
|
|
|
|
|
if hf_checkpoint is None:
|
|
|
|
|
hg_checkpoint = cfg.get('hf_checkpoint')
|
|
|
|
|
if hf_revision is None:
|
|
|
|
|
hg_revision = cfg.get('hf_revision')
|
|
|
|
|
if cfg is None or (('url' not in cfg or not cfg['url']) and hf_checkpoint is None):
|
|
|
|
|
_logger.warning("No pretrained weights exist for this model. Using random initialization.")
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
if cfg.get('hf_checkpoint') is not None:
|
|
|
|
|
if hf_checkpoint is not None:
|
|
|
|
|
# TODO, progress is ignored.
|
|
|
|
|
url = hf_hub_url(cfg['hf_checkpoint'], "pytorch_model.pth", revision=cfg.get('hf_revision'))
|
|
|
|
|
url = hf_hub_url(hf_checkpoint, "pytorch_model.pth", revision=hf_revision)
|
|
|
|
|
cached_filed = cached_download(
|
|
|
|
|
url, library_name="timm", library_version=__version__, cache_dir=get_cache_dir()
|
|
|
|
|
)
|
|
|
|
@ -354,6 +360,8 @@ def build_model_with_cfg(
|
|
|
|
|
pretrained_custom_load: bool = False,
|
|
|
|
|
**kwargs):
|
|
|
|
|
pruned = kwargs.pop('pruned', False)
|
|
|
|
|
hf_checkpoint = kwargs.pop('hf_checkpoint', None)
|
|
|
|
|
hf_revision = kwargs.pop('hf_revision', None)
|
|
|
|
|
features = False
|
|
|
|
|
feature_cfg = feature_cfg or {}
|
|
|
|
|
|
|
|
|
@ -371,14 +379,15 @@ def build_model_with_cfg(
|
|
|
|
|
|
|
|
|
|
# for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
|
|
|
|
|
num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
|
|
|
|
|
if pretrained:
|
|
|
|
|
if pretrained or hf_checkpoint is not None:
|
|
|
|
|
if pretrained_custom_load:
|
|
|
|
|
load_custom_pretrained(model)
|
|
|
|
|
else:
|
|
|
|
|
load_pretrained(
|
|
|
|
|
model,
|
|
|
|
|
num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),
|
|
|
|
|
filter_fn=pretrained_filter_fn, strict=pretrained_strict)
|
|
|
|
|
filter_fn=pretrained_filter_fn, strict=pretrained_strict,
|
|
|
|
|
hf_checkpoint=hf_checkpoint, hf_revision=hf_revision)
|
|
|
|
|
|
|
|
|
|
if features:
|
|
|
|
|
feature_cls = FeatureListNet
|
|
|
|
@ -394,3 +403,17 @@ def build_model_with_cfg(
|
|
|
|
|
model.default_cfg = default_cfg_for_features(default_cfg) # add back default_cfg
|
|
|
|
|
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_cfg_from_json(json_file: Union[str, os.PathLike]):
|
|
|
|
|
with open(json_file, "r", encoding="utf-8") as reader:
|
|
|
|
|
text = reader.read()
|
|
|
|
|
return json.loads(text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_hf_checkpoint_config(checkpoint: str, revision: Optional[str] = None):
|
|
|
|
|
url = hf_hub_url(checkpoint, "config.json", revision=revision)
|
|
|
|
|
cached_filed = cached_download(
|
|
|
|
|
url, library_name="timm", library_version=__version__, cache_dir=get_cache_dir()
|
|
|
|
|
)
|
|
|
|
|
return load_cfg_from_json(cached_filed)
|
|
|
|
|