Merge branch 'nateraw-hf-save-and-push'

pull/1007/head
Ross Wightman 3 years ago
commit a22b85c1b9

@ -11,11 +11,11 @@ from typing import Any, Callable, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.hub import load_state_dict_from_url
from .features import FeatureListNet, FeatureDictNet, FeatureHookNet from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
from .fx_features import FeatureGraphNet from .fx_features import FeatureGraphNet
from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf, load_state_dict_from_url from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf
from .layers import Conv2dSame, Linear from .layers import Conv2dSame, Linear
@ -184,12 +184,12 @@ def load_pretrained(model, default_cfg=None, num_classes=1000, in_chans=3, filte
if not pretrained_url and not hf_hub_id: if not pretrained_url and not hf_hub_id:
_logger.warning("No pretrained weights exist for this model. Using random initialization.") _logger.warning("No pretrained weights exist for this model. Using random initialization.")
return return
if hf_hub_id and has_hf_hub(necessary=not pretrained_url): if pretrained_url:
_logger.info(f'Loading pretrained weights from Hugging Face hub ({hf_hub_id})')
state_dict = load_state_dict_from_hf(hf_hub_id)
else:
_logger.info(f'Loading pretrained weights from url ({pretrained_url})') _logger.info(f'Loading pretrained weights from url ({pretrained_url})')
state_dict = load_state_dict_from_url(pretrained_url, progress=progress, map_location='cpu') state_dict = load_state_dict_from_url(pretrained_url, progress=progress, map_location='cpu')
elif hf_hub_id and has_hf_hub(necessary=True):
_logger.info(f'Loading pretrained weights from Hugging Face hub ({hf_hub_id})')
state_dict = load_state_dict_from_hf(hf_hub_id)
if filter_fn is not None: if filter_fn is not None:
# for backwards compat with filter fn that take one arg, try one first, the two # for backwards compat with filter fn that take one arg, try one first, the two
try: try:

@ -2,10 +2,11 @@ import json
import logging import logging
import os import os
from functools import partial from functools import partial
from typing import Union, Optional from pathlib import Path
from typing import Union
import torch import torch
from torch.hub import load_state_dict_from_url, download_url_to_file, urlparse, HASH_REGEX from torch.hub import HASH_REGEX, download_url_to_file, urlparse
try: try:
from torch.hub import get_dir from torch.hub import get_dir
except ImportError: except ImportError:
@ -13,12 +14,12 @@ except ImportError:
from timm import __version__ from timm import __version__
try: try:
from huggingface_hub import hf_hub_url from huggingface_hub import HfApi, HfFolder, Repository, cached_download, hf_hub_url
from huggingface_hub import cached_download
cached_download = partial(cached_download, library_name="timm", library_version=__version__) cached_download = partial(cached_download, library_name="timm", library_version=__version__)
_has_hf_hub = True
except ImportError: except ImportError:
hf_hub_url = None
cached_download = None cached_download = None
_has_hf_hub = False
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
@ -53,11 +54,11 @@ def download_cached_file(url, check_hash=True, progress=False):
def has_hf_hub(necessary=False):
    """Report whether the optional ``huggingface_hub`` package was importable.

    Args:
        necessary: When true, a missing package is treated as an error
            instead of being reported through the return value.

    Returns:
        The module-level ``_has_hf_hub`` flag.

    Raises:
        RuntimeError: If the package is missing and ``necessary`` is true.
    """
    if _has_hf_hub:
        return True
    if necessary:
        # No HF Hub module installed and the caller cannot continue without it.
        raise RuntimeError(
            'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
    return False
def hf_split(hf_id): def hf_split(hf_id):
@ -94,3 +95,77 @@ def load_state_dict_from_hf(model_id: str):
cached_file = _download_from_hf(model_id, 'pytorch_model.bin') cached_file = _download_from_hf(model_id, 'pytorch_model.bin')
state_dict = torch.load(cached_file, map_location='cpu') state_dict = torch.load(cached_file, map_location='cpu')
return state_dict return state_dict
def save_for_hf(model, save_directory, model_config=None):
    """Write a model's weights and a Hugging Face style config to a directory.

    ``save_directory`` is created (parents included) when it does not exist.
    Two files are written into it: ``pytorch_model.bin`` holding
    ``model.state_dict()`` and ``config.json`` holding the merged config.

    Args:
        model: Object exposing ``state_dict()``, ``default_cfg``,
            ``num_classes`` and ``num_features``.
        save_directory: Target directory, as a string or path-like object.
        model_config: Optional dict of overrides and extra entries for the
            config. ``num_classes``, ``num_features`` and ``labels`` replace
            the values derived from the model; every remaining key is merged
            on top of ``model.default_cfg``. The dict is not modified.

    Raises:
        RuntimeError: If ``huggingface_hub`` is not installed.
    """
    # has_hf_hub(True) raises on its own when the package is missing. Calling
    # it directly instead of through `assert` keeps the check under `python -O`.
    has_hf_hub(True)

    # Copy so the pops below never touch the caller's dict.
    model_config = dict(model_config) if model_config else {}

    save_directory = Path(save_directory)
    save_directory.mkdir(exist_ok=True, parents=True)

    weights_path = save_directory / 'pytorch_model.bin'
    torch.save(model.state_dict(), weights_path)

    config_path = save_directory / 'config.json'
    # Copy so the export-only keys are not written back into model.default_cfg.
    hf_config = dict(model.default_cfg)
    hf_config['num_classes'] = model_config.pop('num_classes', model.num_classes)
    hf_config['num_features'] = model_config.pop('num_features', model.num_features)
    # Placeholder label names, one per class, unless the caller supplies real ones.
    hf_config['labels'] = model_config.pop(
        'labels', [f"LABEL_{i}" for i in range(hf_config['num_classes'])])
    hf_config.update(model_config)

    with config_path.open('w') as f:
        json.dump(hf_config, f, indent=2)
def push_to_hf_hub(
        model,
        local_dir,
        repo_namespace_or_url=None,
        commit_message='Add model',
        use_auth_token=True,
        git_email=None,
        git_user=None,
        revision=None,
        model_config=None,
):
    """Save a model into a Hugging Face Hub repository and commit it.

    The target repository is ``https://huggingface.co/{owner}/{name}``. Owner
    and name are the last two path components of ``repo_namespace_or_url``
    when it is given. Otherwise the owner is the account the auth token
    belongs to and the name is the final component of ``local_dir``.

    Args:
        model: Model to export; passed through to ``save_for_hf``.
        local_dir: Local working directory handed to ``Repository``.
        repo_namespace_or_url: ``owner/name`` or a full repository URL.
        commit_message: Message for the commit that adds the files.
        use_auth_token: An explicit token string, or a non-string value to
            use the token stored by ``HfFolder``.
        git_email: Forwarded to ``Repository``.
        git_user: Forwarded to ``Repository``.
        revision: Forwarded to ``Repository``.
        model_config: Optional config overrides passed to ``save_for_hf``.

    Returns:
        The value of ``repo.git_remote_url()``.

    Raises:
        RuntimeError: If ``huggingface_hub`` is not installed.
        ValueError: If ``repo_namespace_or_url`` does not contain an
            ``owner/name`` pair, or if no auth token can be found.
    """
    # Fail early with the install hint. Without this guard a missing package
    # surfaces as a NameError on HfFolder / HfApi / Repository below.
    has_hf_hub(True)

    if repo_namespace_or_url:
        repo_parts = repo_namespace_or_url.rstrip('/').split('/')[-2:]
        if len(repo_parts) != 2 or not all(repo_parts):
            raise ValueError(
                f"Expected `repo_namespace_or_url` of the form 'owner/name' or a full "
                f"repository URL, got {repo_namespace_or_url!r}."
            )
        repo_owner, repo_name = repo_parts
    else:
        if isinstance(use_auth_token, str):
            token = use_auth_token
        else:
            token = HfFolder.get_token()

        if token is None:
            raise ValueError(
                "You must login to the Hugging Face hub on this computer by typing `huggingface-cli login` and "
                "entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own "
                "token as the `use_auth_token` argument."
            )

        repo_owner = HfApi().whoami(token)['name']
        repo_name = Path(local_dir).name

    repo_url = f'https://huggingface.co/{repo_owner}/{repo_name}'

    repo = Repository(
        local_dir,
        clone_from=repo_url,
        use_auth_token=use_auth_token,
        git_user=git_user,
        git_email=git_email,
        revision=revision,
    )

    # Prepare a default model card that includes the necessary tags to enable inference.
    # NOTE(review): Hub model card metadata documents `library_name`; confirm
    # whether `library_tag` is still honoured before changing this string.
    readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_name}'

    # NOTE(review): presumably files written inside this context are committed
    # and pushed when it exits; confirm against the huggingface_hub docs.
    with repo.commit(commit_message):
        # Save model weights and config.
        save_for_hf(model, repo.local_dir, model_config=model_config)

        # Save a model card if it doesn't exist.
        readme_path = Path(repo.local_dir) / 'README.md'
        if not readme_path.exists():
            readme_path.write_text(readme_text)

    return repo.git_remote_url()

Loading…
Cancel
Save