PoC for using HF model hub

pull/440/head
Sylvain Gugger 5 years ago
parent d8e69206be
commit ebe69dd4d3

@ -17,13 +17,31 @@ try:
except ImportError: except ImportError:
from torch.hub import _get_torch_home as get_dir from torch.hub import _get_torch_home as get_dir
from huggingface_hub import cached_download
from .features import FeatureListNet, FeatureDictNet, FeatureHookNet from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
from .layers import Conv2dSame, Linear from .layers import Conv2dSame, Linear
from ..version import __version__
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
def get_cache_dir():
"""
Returns the location of the directory where models are cached (and creates it if necessary).
"""
# Issue warning to move data if old env is set
if os.getenv('TORCH_MODEL_ZOO'):
_logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, 'checkpoints')
os.makedirs(model_dir, exist_ok=True)
return model_dir
def load_state_dict(checkpoint_path, use_ema=False): def load_state_dict(checkpoint_path, use_ema=False):
if checkpoint_path and os.path.isfile(checkpoint_path): if checkpoint_path and os.path.isfile(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location='cpu') checkpoint = torch.load(checkpoint_path, map_location='cpu')
@ -120,25 +138,10 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_
return return
url = cfg['url'] url = cfg['url']
# Issue warning to move data if old env is set # TODO, progress and check_hash are ignored.
if os.getenv('TORCH_MODEL_ZOO'): cached_filed = cached_download(
_logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead') url, library_name="timm", library_version=__version__, cache_dir=get_cache_dir()
)
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, 'checkpoints')
os.makedirs(model_dir, exist_ok=True)
parts = urlparse(url)
filename = os.path.basename(parts.path)
cached_file = os.path.join(model_dir, filename)
if not os.path.exists(cached_file):
_logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
hash_prefix = None
if check_hash:
r = HASH_REGEX.search(filename) # r is Optional[Match[str]]
hash_prefix = r.group(1) if r else None
download_url_to_file(url, cached_file, hash_prefix, progress=progress)
if load_fn is not None: if load_fn is not None:
load_fn(model, cached_file) load_fn(model, cached_file)
@ -180,7 +183,11 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
_logger.warning("No pretrained weights exist for this model. Using random initialization.") _logger.warning("No pretrained weights exist for this model. Using random initialization.")
return return
state_dict = load_state_dict_from_url(cfg['url'], progress=progress, map_location='cpu') # TODO, progress is ignored.
cached_filed = cached_download(
cfg['url'], library_name="timm", library_version=__version__, cache_dir=get_cache_dir()
)
state_dict = torch.load(cached_filed, map_location='cpu')
if filter_fn is not None: if filter_fn is not None:
state_dict = filter_fn(state_dict) state_dict = filter_fn(state_dict)

@ -52,7 +52,7 @@ default_cfgs = {
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50_ram-a26f946b.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50_ram-a26f946b.pth',
interpolation='bicubic'), interpolation='bicubic'),
'resnet50d': _cfg( 'resnet50d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth', url='https://huggingface.co/sgugger/resnet50d/resolve/main/pytorch_model.pth',
interpolation='bicubic', first_conv='conv1.0'), interpolation='bicubic', first_conv='conv1.0'),
'resnet101': _cfg(url='', interpolation='bicubic'), 'resnet101': _cfg(url='', interpolation='bicubic'),
'resnet101d': _cfg( 'resnet101d': _cfg(

Loading…
Cancel
Save