Make `create_model` work from a checkpoint

pull/440/head
Sylvain Gugger 4 years ago
parent 9857e12c0c
commit 4d0b5feed0

timm/models/factory.py

@@ -1,5 +1,5 @@
 from .registry import is_model, is_model_in_modules, model_entrypoint
-from .helpers import load_checkpoint
+from .helpers import load_checkpoint, load_hf_checkpoint_config
 from .layers import set_layer_config
@@ -52,7 +52,16 @@ def create_model(
         create_fn = model_entrypoint(model_name)
         model = create_fn(**model_args, **kwargs)
     else:
-        raise RuntimeError('Unknown model (%s)' % model_name)
+        try:
+            model_cfg = load_hf_checkpoint_config(model_name, revision=kwargs.get("hf_revision"))
+            # This does not work, but there is probably a way to have the values in the config
+            # override the defaults if needed.
+            # model_args["model_cfg"] = model_cfg
+            create_fn = model_entrypoint(model_cfg.pop("architecture"))
+            model = create_fn(**model_args, **kwargs)
+        except Exception as e:
+            raise RuntimeError('Unknown model or checkpoint from the Hugging Face hub (%s)' % model_name) from e

     if checkpoint_path:
         load_checkpoint(model, checkpoint_path)
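
With this change, a model_name that is not in the registry is treated as a checkpoint id on the Hugging Face hub: create_model downloads the repo's config.json and instantiates whatever its "architecture" key names. A minimal sketch of the intended call; the repo id is hypothetical:

    import timm

    # 'org/resnet50-ft' is not a registered architecture, so create_model
    # falls back to the hub: it fetches the repo's config.json and builds the
    # model named by the config's "architecture" key. hf_revision optionally
    # pins a branch, tag, or commit of that repo.
    model = timm.create_model('org/resnet50-ft', hf_revision='main')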

timm/models/helpers.py

@@ -3,11 +3,12 @@
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import logging
+import json
 import os
 import math
 from collections import OrderedDict
 from copy import deepcopy
-from typing import Callable
+from typing import Callable, Optional, Union

 import torch
 import torch.nn as nn
@@ -183,16 +184,21 @@ def adapt_input_conv(in_chans, conv_weight):
     return conv_weight


-def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
+def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True,
+                    progress=False, hf_checkpoint=None, hf_revision=None):
     if cfg is None:
         cfg = getattr(model, 'default_cfg')
-    if cfg is None or 'url' not in cfg or not cfg['url'] or 'hf_checkpoint' not in cfg or not cfg['hf_checkpoint']:
+    if hf_checkpoint is None and cfg is not None:
+        hf_checkpoint = cfg.get('hf_checkpoint')
+    if hf_revision is None and cfg is not None:
+        hf_revision = cfg.get('hf_revision')
+    if cfg is None or (('url' not in cfg or not cfg['url']) and hf_checkpoint is None):
         _logger.warning("No pretrained weights exist for this model. Using random initialization.")
         return
-    if cfg.get('hf_checkpoint') is not None:
+    if hf_checkpoint is not None:
         # TODO: progress is ignored.
-        url = hf_hub_url(cfg['hf_checkpoint'], "pytorch_model.pth", revision=cfg.get('hf_revision'))
+        url = hf_hub_url(hf_checkpoint, "pytorch_model.pth", revision=hf_revision)
         cached_file = cached_download(
             url, library_name="timm", library_version=__version__, cache_dir=get_cache_dir()
         )
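
The explicit hf_checkpoint / hf_revision arguments take precedence; left as None they fall back to the model's default_cfg, then to the usual cfg['url'] download path. A hypothetical direct call (the repo id is made up):

    import timm
    from timm.models.helpers import load_pretrained

    model = timm.create_model('resnet50', pretrained=False)
    # The explicit argument overrides any 'hf_checkpoint' entry in
    # model.default_cfg; with neither set, the helper downloads cfg['url'].
    load_pretrained(model, hf_checkpoint='org/resnet50-ft', hf_revision='main')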
@@ -354,6 +360,8 @@ def build_model_with_cfg(
         pretrained_custom_load: bool = False,
         **kwargs):
     pruned = kwargs.pop('pruned', False)
+    hf_checkpoint = kwargs.pop('hf_checkpoint', None)
+    hf_revision = kwargs.pop('hf_revision', None)
     features = False
     feature_cfg = feature_cfg or {}
@@ -371,14 +379,15 @@ def build_model_with_cfg(
     # for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
     num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))

-    if pretrained:
+    if pretrained or hf_checkpoint is not None:
         if pretrained_custom_load:
             load_custom_pretrained(model)
         else:
             load_pretrained(
                 model,
                 num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),
-                filter_fn=pretrained_filter_fn, strict=pretrained_strict)
+                filter_fn=pretrained_filter_fn, strict=pretrained_strict,
+                hf_checkpoint=hf_checkpoint, hf_revision=hf_revision)

     if features:
         feature_cls = FeatureListNet
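
Because build_model_with_cfg pops the new kwargs before they reach the model constructor, every entrypoint built on it accepts them transparently, and a hub checkpoint now triggers weight loading even when pretrained is False. A sketch with an illustrative repo id and revision:

    import timm

    # pretrained defaults to False, but the hf_checkpoint kwarg alone
    # satisfies the new 'pretrained or hf_checkpoint is not None' test and
    # routes through load_pretrained; repo id and tag are illustrative.
    model = timm.create_model(
        'resnet50d', hf_checkpoint='org/resnet50d-weights', hf_revision='v1.0')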
@@ -394,3 +403,17 @@ def build_model_with_cfg(
         model.default_cfg = default_cfg_for_features(default_cfg)  # add back default_cfg

     return model
+
+
+def load_cfg_from_json(json_file: Union[str, os.PathLike]):
+    with open(json_file, "r", encoding="utf-8") as reader:
+        text = reader.read()
+    return json.loads(text)
+
+
+def load_hf_checkpoint_config(checkpoint: str, revision: Optional[str] = None):
+    url = hf_hub_url(checkpoint, "config.json", revision=revision)
+    cached_file = cached_download(
+        url, library_name="timm", library_version=__version__, cache_dir=get_cache_dir()
+    )
+    return load_cfg_from_json(cached_file)
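
A quick sketch of the new config helper. The repo id is hypothetical; the only requirement is that the repo hosts a config.json whose "architecture" key names a registered timm entrypoint:

    from timm.models.helpers import load_hf_checkpoint_config

    # Downloads and parses config.json from the (hypothetical) repo;
    # create_model pops the 'architecture' key to look up the entrypoint.
    cfg = load_hf_checkpoint_config('org/resnet50-ft', revision='main')
    print(cfg['architecture'])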
