Merge branch 'hf-save-and-push' of https://github.com/nateraw/pytorch-image-models into nateraw-hf-save-and-push

pull/1007/head
Ross Wightman 3 years ago
commit 8a83c41d7b

@ -11,11 +11,11 @@ from typing import Any, Callable, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.hub import load_state_dict_from_url
from .features import FeatureListNet, FeatureDictNet, FeatureHookNet from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
from .fx_features import FeatureGraphNet from .fx_features import FeatureGraphNet
from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf, load_state_dict_from_url from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf
from .layers import Conv2dSame, Linear from .layers import Conv2dSame, Linear

@ -2,10 +2,11 @@ import json
import logging import logging
import os import os
from functools import partial from functools import partial
from typing import Union, Optional from pathlib import Path
from typing import Union
import torch import torch
from torch.hub import load_state_dict_from_url, download_url_to_file, urlparse, HASH_REGEX from torch.hub import HASH_REGEX, download_url_to_file, urlparse
try: try:
from torch.hub import get_dir from torch.hub import get_dir
except ImportError: except ImportError:
@ -13,8 +14,7 @@ except ImportError:
from timm import __version__ from timm import __version__
try: try:
from huggingface_hub import hf_hub_url from huggingface_hub import HfApi, HfFolder, Repository, cached_download, hf_hub_url
from huggingface_hub import cached_download
cached_download = partial(cached_download, library_name="timm", library_version=__version__) cached_download = partial(cached_download, library_name="timm", library_version=__version__)
except ImportError: except ImportError:
hf_hub_url = None hf_hub_url = None
@ -94,3 +94,77 @@ def load_state_dict_from_hf(model_id: str):
cached_file = _download_from_hf(model_id, 'pytorch_model.bin') cached_file = _download_from_hf(model_id, 'pytorch_model.bin')
state_dict = torch.load(cached_file, map_location='cpu') state_dict = torch.load(cached_file, map_location='cpu')
return state_dict return state_dict
def save_pretrained_for_hf(model, save_directory, **config_kwargs):
assert has_hf_hub(True)
save_directory = Path(save_directory)
save_directory.mkdir(exist_ok=True, parents=True)
weights_path = save_directory / 'pytorch_model.bin'
torch.save(model.state_dict(), weights_path)
config_path = save_directory / 'config.json'
config = model.default_cfg
config['num_classes'] = config_kwargs.pop('num_classes', model.num_classes)
config['num_features'] = config_kwargs.pop('num_features', model.num_features)
config['labels'] = config_kwargs.pop('labels', [f"LABEL_{i}" for i in range(config['num_classes'])])
config.update(config_kwargs)
with config_path.open('w') as f:
json.dump(config, f, indent=2)
def push_to_hf_hub(
model,
local_dir,
repo_namespace_or_url=None,
commit_message='Add model',
use_auth_token=True,
git_email=None,
git_user=None,
revision=None,
**config_kwargs
):
if repo_namespace_or_url:
repo_owner, repo_name = repo_namespace_or_url.rstrip('/').split('/')[-2:]
else:
if isinstance(use_auth_token, str):
token = use_auth_token
else:
token = HfFolder.get_token()
if token is None:
raise ValueError(
"You must login to the Hugging Face hub on this computer by typing `transformers-cli login` and "
"entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own "
"token as the `use_auth_token` argument."
)
repo_owner = HfApi().whoami(token)['name']
repo_name = Path(local_dir).name
repo_url = f'https://huggingface.co/{repo_owner}/{repo_name}'
repo = Repository(
local_dir,
clone_from=repo_url,
use_auth_token=use_auth_token,
git_user=git_user,
git_email=git_email,
revision=revision,
)
# Prepare a default model card that includes the necessary tags to enable inference.
readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_name}'
with repo.commit(commit_message):
# Save model weights and config.
save_pretrained_for_hf(model, repo.local_dir, **config_kwargs)
# Save a model card if it doesn't exist.
readme_path = Path(repo.local_dir) / 'README.md'
if not readme_path.exists():
readme_path.write_text(readme_text)
return repo.git_remote_url()

Loading…
Cancel
Save