Merge branch 'rwightman:main' into main

pull/1583/head
Fredo Guan 2 years ago committed by GitHub
commit 84178fca60

@@ -21,7 +21,25 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
## What's New
### 🤗 Survey: Feedback Appreciated 🤗
For a few months now, `timm` has been part of the Hugging Face ecosystem. Yearly, we survey users of our tools to see what we could do better, what we need to continue doing, or what we need to stop doing.
If you have a couple of minutes and want to participate in shaping the future of the ecosystem, please share your thoughts:
[**hf.co/oss-survey**](https://hf.co/oss-survey) 🙏
### Dec 8, 2022
* Add 'EVA l' to `vision_transformer.py`, MAE style ViT-L/14 MIM pretrain w/ EVA-CLIP targets, FT on ImageNet-1k (w/ ImageNet-22k intermediate for some)
* original source: https://github.com/baaivision/EVA
| model | top1 | param_count | gmac | macts | hub |
|:------------------------------------------|-----:|------------:|------:|------:|:----------------------------------------|
| eva_large_patch14_336.in22k_ft_in22k_in1k | 89.2 | 304.5 | 191.1 | 270.2 | [link](https://huggingface.co/BAAI/EVA) |
| eva_large_patch14_336.in22k_ft_in1k | 88.7 | 304.5 | 191.1 | 270.2 | [link](https://huggingface.co/BAAI/EVA) |
| eva_large_patch14_196.in22k_ft_in22k_in1k | 88.6 | 304.1 | 61.6 | 63.5 | [link](https://huggingface.co/BAAI/EVA) |
| eva_large_patch14_196.in22k_ft_in1k | 87.9 | 304.1 | 61.6 | 63.5 | [link](https://huggingface.co/BAAI/EVA) |
### Dec 6, 2022
* Add 'EVA g', BEiT style ViT-g/14 model weights w/ both MIM pretrain and CLIP pretrain to `beit.py`.
* original source: https://github.com/baaivision/EVA
* paper: https://arxiv.org/abs/2211.07636
@@ -33,7 +51,7 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
| eva_giant_patch14_336.clip_ft_in1k | 89.4 | 1013 | 620.6 | 550.7 | [link](https://huggingface.co/BAAI/EVA) |
| eva_giant_patch14_224.clip_ft_in1k | 89.1 | 1012.6 | 267.2 | 192.6 | [link](https://huggingface.co/BAAI/EVA) |
### Dec 5, 2022
* Pre-release (`0.8.0dev0`) of multi-weight support (`model_arch.pretrained_tag`). Install with `pip install --pre timm`
* vision_transformer, maxvit, convnext are the first three model impl w/ support
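The multi-weight tags above can be exercised directly from the pre-release; a minimal sketch, assuming `pip install --pre timm` and network access to the listed weights:

```python
import timm

# select a specific weight set via the `architecture.pretrained_tag` form described above
model = timm.create_model('eva_large_patch14_196.in22k_ft_in1k', pretrained=True)
model.eval()

# tags are also discoverable through the listing helpers
print(timm.list_pretrained('eva_*'))
```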

@@ -16,7 +16,7 @@ import argparse
import os
import glob
import hashlib
from timm.models import load_state_dict
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
parser.add_argument('--input', default='', type=str, metavar='PATH',

@@ -19,7 +19,8 @@ import torch.nn as nn
import torch.nn.parallel
from timm.data import resolve_data_config
from timm.layers import set_fast_norm
from timm.models import create_model, is_model, list_models
from timm.optim import create_optimizer_v2
from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry

@@ -13,7 +13,7 @@ import os
import hashlib
import shutil
from collections import OrderedDict
from timm.models import load_state_dict
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',

@@ -1,4 +1,3 @@
dependencies = ['torch']
import timm
globals().update(timm.models._registry._model_entrypoints)
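With every registered entrypoint copied into the `hubconf.py` globals as above, models stay loadable through `torch.hub`; a minimal sketch (the architecture name is an arbitrary example, not part of this change):

```python
import torch

# hubconf.py exposes each timm model entrypoint, so torch.hub can build them directly
model = torch.hub.load('rwightman/pytorch-image-models', 'resnet50', pretrained=True)
model.eval()
```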

@@ -5,11 +5,11 @@ An example inference script that outputs top-k class ids for images in a folder
Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
import argparse
import json
import logging
import os
import time
from contextlib import suppress
from functools import partial
@@ -17,12 +17,11 @@ import numpy as np
import pandas as pd
import torch
from timm.data import create_dataset, create_loader, resolve_data_config
from timm.layers import apply_test_time_pool
from timm.models import create_model
from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser
try:
from apex import amp
has_apex = True

@@ -1,10 +1,7 @@
import torch
import torch.nn as nn
from timm.layers import create_act_layer, set_layer_config
class MLP(nn.Module):

@@ -14,7 +14,7 @@ except ImportError:
import timm
from timm import list_models, create_model, set_scriptable, get_pretrained_cfg_value
from timm.models._features_fx import _leaf_modules, _autowrap_functions
if hasattr(torch._C, '_jit_set_profiling_executor'):
# legacy executor is too slow to compile large models for unit tests

@@ -1,4 +1,4 @@
from .version import __version__
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable
from .models import create_model, list_models, list_pretrained, is_model, list_modules, model_entrypoint, \
is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value

@@ -1,4 +1,4 @@
""" AutoAugment, RandAugment, AugMix, and 3-Augment for PyTorch
This code implements the searched ImageNet policies with various tweaks and improvements and
does not include any of the search code.
@@ -9,18 +9,24 @@ AA and RA Implementation adapted from:
AugMix adapted from:
https://github.com/google-research/augmix
3-Augment based on: https://github.com/facebookresearch/deit/blob/main/README_revenge.md
Papers:
AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501
Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172
RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781
3-Augment: DeiT III: Revenge of the ViT - https://arxiv.org/abs/2204.07118
Hacked together by / Copyright 2019, Ross Wightman
"""
import random
import math
import re
from functools import partial
from typing import Dict, List, Optional, Union
from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageFilter
import PIL
import numpy as np
@@ -175,6 +181,24 @@ def sharpness(img, factor, **__):
return ImageEnhance.Sharpness(img).enhance(factor)
def gaussian_blur(img, factor, **__):
img = img.filter(ImageFilter.GaussianBlur(radius=factor))
return img
def gaussian_blur_rand(img, factor, **__):
radius_min = 0.1
radius_max = 2.0
img = img.filter(ImageFilter.GaussianBlur(radius=random.uniform(radius_min, radius_max * factor)))
return img
def desaturate(img, factor, **_):
factor = min(1., max(0., 1. - factor))
# enhance factor 0 = grayscale, 1.0 = no-change
return ImageEnhance.Color(img).enhance(factor)
def _randomly_negate(v):
"""With 50% prob, negate the value"""
return -v if random.random() > 0.5 else v
@@ -200,6 +224,14 @@ def _enhance_increasing_level_to_arg(level, _hparams):
return level,
def _minmax_level_to_arg(level, _hparams, min_val=0., max_val=1.0, clamp=True):
level = (level / _LEVEL_DENOM)
level = min_val + (max_val - min_val) * level
if clamp:
level = max(min_val, min(max_val, level))
return level,
def _shear_level_to_arg(level, _hparams):
# range [-0.3, 0.3]
level = (level / _LEVEL_DENOM) * 0.3
@@ -246,7 +278,7 @@ def _posterize_original_level_to_arg(level, _hparams):
def _solarize_level_to_arg(level, _hparams):
# range [0, 256]
# intensity/severity of augmentation decreases with level
return min(256, int((level / _LEVEL_DENOM) * 256)),
def _solarize_increasing_level_to_arg(level, _hparams):
@@ -257,7 +289,7 @@ def _solarize_increasing_level_to_arg(level, _hparams):
def _solarize_add_level_to_arg(level, _hparams):
# range [0, 110]
return min(128, int((level / _LEVEL_DENOM) * 110)),
LEVEL_TO_ARG = {
@@ -286,6 +318,9 @@ LEVEL_TO_ARG = {
'TranslateY': _translate_abs_level_to_arg,
'TranslateXRel': _translate_rel_level_to_arg,
'TranslateYRel': _translate_rel_level_to_arg,
'Desaturate': partial(_minmax_level_to_arg, min_val=0.5, max_val=1.0),
'GaussianBlur': partial(_minmax_level_to_arg, min_val=0.1, max_val=2.0),
'GaussianBlurRand': _minmax_level_to_arg,
}
@@ -314,6 +349,9 @@ NAME_TO_OP = {
'TranslateY': translate_y_abs,
'TranslateXRel': translate_x_rel,
'TranslateYRel': translate_y_rel,
'Desaturate': desaturate,
'GaussianBlur': gaussian_blur,
'GaussianBlurRand': gaussian_blur_rand,
}
@@ -347,6 +385,7 @@ class AugmentOp:
if self.magnitude_std > 0:
# magnitude randomization enabled
if self.magnitude_std == float('inf'):
# inf == uniform sampling
magnitude = random.uniform(0, magnitude)
elif self.magnitude_std > 0:
magnitude = random.gauss(magnitude, self.magnitude_std)
@@ -499,6 +538,16 @@ def auto_augment_policy_originalr(hparams):
return pc
def auto_augment_policy_3a(hparams):
policy = [
[('Solarize', 1.0, 5)], # 128 solarize threshold @ 5 magnitude
[('Desaturate', 1.0, 10)], # grayscale at 10 magnitude
[('GaussianBlurRand', 1.0, 10)],
]
pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
return pc
def auto_augment_policy(name='v0', hparams=None):
hparams = hparams or _HPARAMS_DEFAULT
if name == 'original':
@@ -509,6 +558,8 @@ def auto_augment_policy(name='v0', hparams=None):
return auto_augment_policy_v0(hparams)
elif name == 'v0r':
return auto_augment_policy_v0r(hparams)
elif name == '3a':
return auto_augment_policy_3a(hparams)
else:
assert False, 'Unknown AA policy (%s)' % name
@@ -534,19 +585,23 @@ class AutoAugment:
return fs
def auto_augment_transform(config_str: str, hparams: Optional[Dict] = None):
"""
Create an AutoAugment transform
Args:
config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
dashes ('-').
The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr').
The remaining sections:
'mstd' - float std deviation of magnitude noise applied
Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
hparams: Other hparams (kwargs) for the AutoAugmentation scheme
Returns:
A PyTorch compatible Transform
"""
config = config_str.split('-')
policy_name = config[0]
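A short usage sketch of the config-string interface documented above; the policy names and hparams values are arbitrary examples:

```python
from PIL import Image
from timm.data.auto_augment import auto_augment_transform

# 'original' policy with magnitude noise std 0.5, per the 'mstd' key described above
aa = auto_augment_transform('original-mstd0.5', hparams={'translate_const': 117, 'img_mean': (124, 116, 104)})
augmented = aa(Image.new('RGB', (224, 224)))

# the new '3a' policy added above is selected the same way
aa_3a = auto_augment_transform('3a', hparams={'img_mean': (124, 116, 104)})
```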
@@ -605,42 +660,80 @@ _RAND_INCREASING_TRANSFORMS = [
]
_RAND_3A = [
'SolarizeIncreasing',
'Desaturate',
'GaussianBlur',
]
_RAND_CHOICE_3A = {
'SolarizeIncreasing': 6,
'Desaturate': 6,
'GaussianBlur': 6,
'Rotate': 3,
'ShearX': 2,
'ShearY': 2,
'PosterizeIncreasing': 1,
'AutoContrast': 1,
'ColorIncreasing': 1,
'SharpnessIncreasing': 1,
'ContrastIncreasing': 1,
'BrightnessIncreasing': 1,
'Equalize': 1,
'Invert': 1,
}
# These experimental weights are based loosely on the relative improvements mentioned in paper.
# They may not result in increased performance, but could likely be tuned to do so.
_RAND_CHOICE_WEIGHTS_0 = {
'Rotate': 3,
'ShearX': 2,
'ShearY': 2,
'TranslateXRel': 1,
'TranslateYRel': 1,
'ColorIncreasing': .25,
'SharpnessIncreasing': 0.25,
'AutoContrast': 0.25,
'SolarizeIncreasing': .05,
'SolarizeAdd': .05,
'ContrastIncreasing': .05,
'BrightnessIncreasing': .05,
'Equalize': .05,
'PosterizeIncreasing': 0.05,
'Invert': 0.05,
}
def _get_weighted_transforms(transforms: Dict):
transforms, probs = list(zip(*transforms.items()))
probs = np.array(probs)
probs = probs / np.sum(probs)
return transforms, probs
def rand_augment_choices(name: str, increasing=True):
if name == 'weights':
return _RAND_CHOICE_WEIGHTS_0
elif name == '3aw':
return _RAND_CHOICE_3A
elif name == '3a':
return _RAND_3A
else:
return _RAND_INCREASING_TRANSFORMS if increasing else _RAND_TRANSFORMS
def rand_augment_ops(
magnitude: Union[int, float] = 10,
prob: float = 0.5,
hparams: Optional[Dict] = None,
transforms: Optional[Union[Dict, List]] = None,
):
hparams = hparams or _HPARAMS_DEFAULT
transforms = transforms or _RAND_TRANSFORMS
return [AugmentOp(
name, prob=prob, magnitude=magnitude, hparams=hparams) for name in transforms]
class RandAugment:
@@ -648,11 +741,16 @@ class RandAugment:
self.ops = ops
self.num_layers = num_layers
self.choice_weights = choice_weights
print(self.ops, self.choice_weights)
def __call__(self, img):
# no replacement when using weighted choice
ops = np.random.choice(
self.ops,
self.num_layers,
replace=self.choice_weights is None,
p=self.choice_weights,
)
for op in ops:
img = op(img)
return img
@@ -665,34 +763,48 @@ class RandAugment:
return fs
def rand_augment_transform(
config_str: str,
hparams: Optional[Dict] = None,
transforms: Optional[Union[str, Dict, List]] = None,
):
""" """
Create a RandAugment transform Create a RandAugment transform
:param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by Args:
dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining config_str (str): String defining configuration of random augmentation. Consists of multiple sections separated
sections, not order sepecific determine by dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand').
The remaining sections, not order sepecific determine
'm' - integer magnitude of rand augment 'm' - integer magnitude of rand augment
'n' - integer num layers (number of transform ops selected per image) 'n' - integer num layers (number of transform ops selected per image)
'w' - integer probabiliy weight index (index of a set of weights to influence choice of op) 'p' - float probability of applying each layer (default 0.5)
'mstd' - float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100) 'mstd' - float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100)
'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10) 'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10)
'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0) 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
't' - str name of transform set to use
Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
'rand-mstd1-tweights' results in mag std 1.0, weighted transforms, default mag of 10 and num_layers 2
hparams (dict): Other hparams (kwargs) for the RandAugmentation scheme
Returns:
A PyTorch compatible Transform
"""
magnitude = _LEVEL_DENOM # default to _LEVEL_DENOM for magnitude (currently 10)
num_layers = 2 # default to 2 ops per image
increasing = False
prob = 0.5
config = config_str.split('-')
assert config[0] == 'rand'
config = config[1:]
for c in config:
if c.startswith('t'):
# NOTE old 'w' key was removed, 'w0' is not equivalent to 'tweights'
val = str(c[1:])
if transforms is None:
transforms = val
else:
# numeric options
cs = re.split(r'(\d.*)', c)
if len(cs) < 2:
continue
@@ -709,17 +821,26 @@ def rand_augment_transform(config_str, hparams):
hparams.setdefault('magnitude_max', int(val))
elif key == 'inc':
if bool(val):
increasing = True
elif key == 'm':
magnitude = int(val)
elif key == 'n':
num_layers = int(val)
elif key == 'p':
prob = float(val)
else:
assert False, 'Unknown RandAugment config section'
if isinstance(transforms, str):
transforms = rand_augment_choices(transforms, increasing=increasing)
elif transforms is None:
transforms = _RAND_INCREASING_TRANSFORMS if increasing else _RAND_TRANSFORMS
choice_weights = None
if isinstance(transforms, Dict):
transforms, choice_weights = _get_weighted_transforms(transforms)
ra_ops = rand_augment_ops(magnitude=magnitude, prob=prob, hparams=hparams, transforms=transforms)
return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
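A usage sketch of the updated config string, including the new 't' (transform set) and 'p' (per-layer probability) keys; values are arbitrary examples:

```python
from PIL import Image
from timm.data.auto_augment import rand_augment_transform

hparams = {'translate_const': 100, 'img_mean': (124, 116, 104)}

# magnitude 9, magnitude noise std 0.5, increasing-severity transform set
ra = rand_augment_transform('rand-m9-mstd0.5-inc1', hparams)

# 't3aw' selects the weighted 3-Augment-style choices (_RAND_CHOICE_3A) defined above
ra_3aw = rand_augment_transform('rand-m9-p0.6-t3aw', hparams)

augmented = ra(Image.new('RGB', (224, 224)))
```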
@@ -740,11 +861,19 @@ _AUGMIX_TRANSFORMS = [
]
def augmix_ops(
magnitude: Union[int, float] = 10,
hparams: Optional[Dict] = None,
transforms: Optional[Union[str, Dict, List]] = None,
):
hparams = hparams or _HPARAMS_DEFAULT
transforms = transforms or _AUGMIX_TRANSFORMS
return [AugmentOp(
name,
prob=1.0,
magnitude=magnitude,
hparams=hparams
) for name in transforms]
class AugMixAugment:
@@ -820,12 +949,13 @@ class AugMixAugment:
return fs
def augment_and_mix_transform(config_str: str, hparams: Optional[Dict] = None):
""" Create AugMix PyTorch transform
Args:
config_str (str): String defining configuration of random augmentation. Consists of multiple sections separated
by dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand').
The remaining sections, not order specific, determine
'm' - integer magnitude (severity) of augmentation mix (default: 3)
'w' - integer width of augmentation chain (default: 3)
'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
@@ -833,9 +963,10 @@ def augment_and_mix_transform(config_str, hparams):
'mstd' - float std deviation of magnitude noise applied (default: 0)
Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2
hparams: Other hparams (kwargs) for the Augmentation transforms
Returns:
A PyTorch compatible Transform
""" """
magnitude = 3 magnitude = 3
width = 3 width = 3
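And the AugMix counterpart, matching the docstring example above (hparams values are arbitrary):

```python
from PIL import Image
from timm.data.auto_augment import augment_and_mix_transform

# severity 5, chain width 4, chain depth 2
am = augment_and_mix_transform('augmix-m5-w4-d2', hparams={'img_mean': (124, 116, 104)})
mixed = am(Image.new('RGB', (224, 224)))
```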

@@ -1,6 +1,7 @@
import os
import pickle
def load_class_map(map_or_filename, root=''):
if isinstance(map_or_filename, dict): if isinstance(map_or_filename, dict):
assert map_or_filename, 'class_map dict must be non-empty'

@@ -59,6 +59,7 @@ def transforms_imagenet_train(
re_count=1,
re_num_splits=0,
separate=False,
force_color_jitter=False,
):
"""
If separate==True, the transforms are returned as a tuple of 3 separate transforms
@@ -77,8 +78,12 @@ def transforms_imagenet_train(
primary_tfl += [transforms.RandomVerticalFlip(p=vflip)]
secondary_tfl = []
disable_color_jitter = False
if auto_augment:
assert isinstance(auto_augment, str)
# color jitter is typically disabled if AA/RA on,
# this allows override without breaking old hparam cfgs
disable_color_jitter = not (force_color_jitter or '3a' in auto_augment)
if isinstance(img_size, (tuple, list)):
img_size_min = min(img_size)
else:
@@ -96,8 +101,9 @@ def transforms_imagenet_train(
secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)]
else:
secondary_tfl += [auto_augment_transform(auto_augment, aa_params)]
if color_jitter is not None and not disable_color_jitter:
# color jitter is enabled when not using AA or when forced
if isinstance(color_jitter, (list, tuple)):
# color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
# or 4 if also augmenting hue
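A sketch of the color-jitter interaction above, assuming these keyword args are reachable through `transforms_imagenet_train` as in this hunk:

```python
from timm.data.transforms_factory import transforms_imagenet_train

# RandAugment active: color jitter is dropped unless explicitly forced back on
tf_ra = transforms_imagenet_train(img_size=224, auto_augment='rand-m9-mstd0.5-inc1')
tf_ra_cj = transforms_imagenet_train(img_size=224, auto_augment='rand-m9-mstd0.5-inc1', force_color_jitter=True)

# '3a' policies keep color jitter by default per the check above
tf_3a = transforms_imagenet_train(img_size=224, auto_augment='3a')
```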

@@ -0,0 +1,44 @@
from .activations import *
from .adaptive_avgmax_pool import \
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
from .blur_pool import BlurPool2d
from .classifier import ClassifierHead, create_classifier
from .cond_conv2d import CondConv2d, get_condconv_initializer
from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
set_layer_config
from .conv2d_same import Conv2dSame, conv2d_same
from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
from .create_act import create_act_layer, get_act_layer, get_act_fn
from .create_attn import get_attn, create_attn
from .create_conv2d import create_conv2d
from .create_norm import get_norm_layer, create_norm_layer
from .create_norm_act import get_norm_act_layer, create_norm_act_layer
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\
EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm
from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
from .gather_excite import GatherExcite
from .global_context import GlobalContext
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
from .inplace_abn import InplaceAbn
from .linear import Linear
from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
from .padding import get_padding, get_same_padding, pad_same
from .patch_embed import PatchEmbed
from .pool2d_same import AvgPool2dSame, create_pool2d
from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
from .selective_kernel import SelectiveKernel
from .separable_conv import SeparableConv2d, SeparableConvNormAct
from .space_to_depth import SpaceToDepthModule
from .split_attn import SplitAttn
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .trace_utils import _assert, _float_to_int
from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_

@@ -65,12 +65,18 @@ from .xception import *
from .xception_aligned import *
from .xcit import *
from ._builder import build_model_with_cfg, load_pretrained, load_custom_pretrained, resolve_pretrained_cfg, \
set_pretrained_download_progress, set_pretrained_check_hash
from ._factory import create_model, parse_model_name, safe_model_name
from ._features import FeatureInfo, FeatureHooks, FeatureHookNet, FeatureListNet, FeatureDictNet
from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, \
register_notrace_module, register_notrace_function
from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_checkpoint, resume_checkpoint
from ._hub import load_model_config_from_hf, load_state_dict_from_hf, push_to_hf_hub
from ._manipulate import model_parameters, named_apply, named_modules, named_modules_with_params, \
group_modules, group_parameters, checkpoint_seq, adapt_input_conv
from ._pretrained import PretrainedCfg, DefaultCfg, \
filter_pretrained_cfg, generate_default_cfgs, split_model_name_tag
from ._prune import adapt_model_from_string
from ._registry import register_model, model_entrypoint, list_models, list_pretrained, is_model, list_modules, \
is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value

@@ -0,0 +1,399 @@
import dataclasses
import logging
from copy import deepcopy
from typing import Optional, Dict, Callable, Any, Tuple
from torch import nn as nn
from torch.hub import load_state_dict_from_url
from timm.models._features import FeatureListNet, FeatureHookNet
from timm.models._features_fx import FeatureGraphNet
from timm.models._helpers import load_state_dict
from timm.models._hub import has_hf_hub, download_cached_file, load_state_dict_from_hf
from timm.models._manipulate import adapt_input_conv
from timm.models._pretrained import PretrainedCfg
from timm.models._prune import adapt_model_from_file
from timm.models._registry import get_pretrained_cfg
_logger = logging.getLogger(__name__)
# Global variables for rarely used pretrained checkpoint download progress and hash check.
# Use set_pretrained_download_progress / set_pretrained_check_hash functions to toggle.
_DOWNLOAD_PROGRESS = False
_CHECK_HASH = False
__all__ = ['set_pretrained_download_progress', 'set_pretrained_check_hash', 'load_custom_pretrained', 'load_pretrained',
'pretrained_cfg_for_features', 'resolve_pretrained_cfg', 'build_model_with_cfg']
def _resolve_pretrained_source(pretrained_cfg):
cfg_source = pretrained_cfg.get('source', '')
pretrained_url = pretrained_cfg.get('url', None)
pretrained_file = pretrained_cfg.get('file', None)
hf_hub_id = pretrained_cfg.get('hf_hub_id', None)
# resolve where to load pretrained weights from
load_from = ''
pretrained_loc = ''
if cfg_source == 'hf-hub' and has_hf_hub(necessary=True):
# hf-hub specified as source via model identifier
load_from = 'hf-hub'
assert hf_hub_id
pretrained_loc = hf_hub_id
else:
# default source == timm or unspecified
if pretrained_file:
load_from = 'file'
pretrained_loc = pretrained_file
elif pretrained_url:
load_from = 'url'
pretrained_loc = pretrained_url
elif hf_hub_id and has_hf_hub(necessary=True):
# hf-hub available as alternate weight source in default_cfg
load_from = 'hf-hub'
pretrained_loc = hf_hub_id
if load_from == 'hf-hub' and pretrained_cfg.get('hf_hub_filename', None):
# if a filename override is set, return tuple for location w/ (hub_id, filename)
pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename']
return load_from, pretrained_loc
def set_pretrained_download_progress(enable=True):
""" Set download progress for pretrained weights on/off (globally). """
global _DOWNLOAD_PROGRESS
_DOWNLOAD_PROGRESS = enable
def set_pretrained_check_hash(enable=True):
""" Set hash checking for pretrained weights on/off (globally). """
global _CHECK_HASH
_CHECK_HASH = enable
def load_custom_pretrained(
model: nn.Module,
pretrained_cfg: Optional[Dict] = None,
load_fn: Optional[Callable] = None,
):
r"""Loads a custom (read non .pth) weight file
Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls
a passed in custom load fn, or the `load_pretrained` model member fn.
If the object is already present in `model_dir`, it's deserialized and returned.
The default value of `model_dir` is ``<hub_dir>/checkpoints`` where
`hub_dir` is the directory returned by :func:`~torch.hub.get_dir`.
Args:
model: The instantiated model to load weights into
pretrained_cfg (dict): Default pretrained model cfg
load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named
'load_pretrained' on the model will be called if it exists
"""
pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
if not pretrained_cfg:
_logger.warning("Invalid pretrained config, cannot load weights.")
return
load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg)
if not load_from:
_logger.warning("No pretrained weights exist for this model. Using random initialization.")
return
if load_from == 'hf-hub': # FIXME
_logger.warning("Hugging Face hub not currently supported for custom load pretrained models.")
elif load_from == 'url':
pretrained_loc = download_cached_file(
pretrained_loc,
check_hash=_CHECK_HASH,
progress=_DOWNLOAD_PROGRESS
)
if load_fn is not None:
load_fn(model, pretrained_loc)
elif hasattr(model, 'load_pretrained'):
model.load_pretrained(pretrained_loc)
else:
_logger.warning("Valid function to load pretrained weights is not available, using random initialization.")
def load_pretrained(
model: nn.Module,
pretrained_cfg: Optional[Dict] = None,
num_classes: int = 1000,
in_chans: int = 3,
filter_fn: Optional[Callable] = None,
strict: bool = True,
):
""" Load pretrained checkpoint
Args:
model (nn.Module) : PyTorch model module
pretrained_cfg (Optional[Dict]): configuration for pretrained weights / target dataset
num_classes (int): num_classes for target model
in_chans (int): in_chans for target model
filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args)
strict (bool): strict load of checkpoint
"""
pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
if not pretrained_cfg:
_logger.warning("Invalid pretrained config, cannot load weights.")
return
load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg)
if load_from == 'file':
_logger.info(f'Loading pretrained weights from file ({pretrained_loc})')
state_dict = load_state_dict(pretrained_loc)
elif load_from == 'url':
_logger.info(f'Loading pretrained weights from url ({pretrained_loc})')
state_dict = load_state_dict_from_url(
pretrained_loc,
map_location='cpu',
progress=_DOWNLOAD_PROGRESS,
check_hash=_CHECK_HASH,
)
elif load_from == 'hf-hub':
_logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
if isinstance(pretrained_loc, (list, tuple)):
state_dict = load_state_dict_from_hf(*pretrained_loc)
else:
state_dict = load_state_dict_from_hf(pretrained_loc)
else:
_logger.warning("No pretrained weights exist or were found for this model. Using random initialization.")
return
if filter_fn is not None:
# for backwards compat with filter fns that take one arg, try one first, then two
try:
state_dict = filter_fn(state_dict)
except TypeError:
state_dict = filter_fn(state_dict, model)
input_convs = pretrained_cfg.get('first_conv', None)
if input_convs is not None and in_chans != 3:
if isinstance(input_convs, str):
input_convs = (input_convs,)
for input_conv_name in input_convs:
weight_name = input_conv_name + '.weight'
try:
state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name])
_logger.info(
f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)')
except NotImplementedError as e:
del state_dict[weight_name]
strict = False
_logger.warning(
f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.')
classifiers = pretrained_cfg.get('classifier', None)
label_offset = pretrained_cfg.get('label_offset', 0)
if classifiers is not None:
if isinstance(classifiers, str):
classifiers = (classifiers,)
if num_classes != pretrained_cfg['num_classes']:
for classifier_name in classifiers:
# completely discard fully connected if model num_classes doesn't match pretrained weights
state_dict.pop(classifier_name + '.weight', None)
state_dict.pop(classifier_name + '.bias', None)
strict = False
elif label_offset > 0:
for classifier_name in classifiers:
# special case for pretrained weights with an extra background class in pretrained weights
classifier_weight = state_dict[classifier_name + '.weight']
state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
classifier_bias = state_dict[classifier_name + '.bias']
state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
model.load_state_dict(state_dict, strict=strict)
def pretrained_cfg_for_features(pretrained_cfg):
pretrained_cfg = deepcopy(pretrained_cfg)
# remove default pretrained cfg fields that don't have much relevance for feature backbone
to_remove = ('num_classes', 'classifier', 'global_pool') # add default final pool size?
for tr in to_remove:
pretrained_cfg.pop(tr, None)
return pretrained_cfg
def _filter_kwargs(kwargs, names):
if not kwargs or not names:
return
for n in names:
kwargs.pop(n, None)
def _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter):
""" Update the default_cfg and kwargs before passing to model
Args:
pretrained_cfg: input pretrained cfg (updated in-place)
kwargs: keyword args passed to model build fn (updated in-place)
kwargs_filter: keyword arg keys that must be removed before model __init__
"""
# Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
default_kwarg_names = ('num_classes', 'global_pool', 'in_chans')
if pretrained_cfg.get('fixed_input_size', False):
# if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size
default_kwarg_names += ('img_size',)
for n in default_kwarg_names:
# for legacy reasons, model __init__ args use img_size + in_chans as separate args while
# pretrained_cfg has one input_size=(C, H, W) entry
if n == 'img_size':
input_size = pretrained_cfg.get('input_size', None)
if input_size is not None:
assert len(input_size) == 3
kwargs.setdefault(n, input_size[-2:])
elif n == 'in_chans':
input_size = pretrained_cfg.get('input_size', None)
if input_size is not None:
assert len(input_size) == 3
kwargs.setdefault(n, input_size[0])
else:
default_val = pretrained_cfg.get(n, None)
if default_val is not None:
kwargs.setdefault(n, pretrained_cfg[n])
# Filter keyword args for task specific model variants (some 'features only' models, etc.)
_filter_kwargs(kwargs, names=kwargs_filter)
def resolve_pretrained_cfg(
variant: str,
pretrained_cfg=None,
pretrained_cfg_overlay=None,
) -> PretrainedCfg:
model_with_tag = variant
pretrained_tag = None
if pretrained_cfg:
if isinstance(pretrained_cfg, dict):
# pretrained_cfg dict passed as arg, validate by converting to PretrainedCfg
pretrained_cfg = PretrainedCfg(**pretrained_cfg)
elif isinstance(pretrained_cfg, str):
pretrained_tag = pretrained_cfg
pretrained_cfg = None
# fallback to looking up pretrained cfg in model registry by variant identifier
if not pretrained_cfg:
if pretrained_tag:
model_with_tag = '.'.join([variant, pretrained_tag])
pretrained_cfg = get_pretrained_cfg(model_with_tag)
if not pretrained_cfg:
_logger.warning(
f"No pretrained configuration specified for {model_with_tag} model. Using a default."
f" Please add a config to the model pretrained_cfg registry or pass explicitly.")
pretrained_cfg = PretrainedCfg() # instance with defaults
pretrained_cfg_overlay = pretrained_cfg_overlay or {}
if not pretrained_cfg.architecture:
pretrained_cfg_overlay.setdefault('architecture', variant)
pretrained_cfg = dataclasses.replace(pretrained_cfg, **pretrained_cfg_overlay)
return pretrained_cfg
def build_model_with_cfg(
model_cls: Callable,
variant: str,
pretrained: bool,
pretrained_cfg: Optional[Dict] = None,
pretrained_cfg_overlay: Optional[Dict] = None,
model_cfg: Optional[Any] = None,
feature_cfg: Optional[Dict] = None,
pretrained_strict: bool = True,
pretrained_filter_fn: Optional[Callable] = None,
kwargs_filter: Optional[Tuple[str]] = None,
**kwargs,
):
""" Build model with specified default_cfg and optional model_cfg
This helper fn aids in the construction of a model including:
* handling default_cfg and associated pretrained weight loading
* passing through optional model_cfg for models with config based arch spec
* features_only model adaptation
* pruning config / model adaptation
Args:
model_cls (nn.Module): model class
variant (str): model variant name
pretrained (bool): load pretrained weights
pretrained_cfg (dict): model's pretrained weight/task config
model_cfg (Optional[Dict]): model's architecture config
feature_cfg (Optional[Dict]): feature extraction adapter config
pretrained_strict (bool): load pretrained weights strictly
pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights
kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model
**kwargs: model args passed through to model __init__
"""
pruned = kwargs.pop('pruned', False)
features = False
feature_cfg = feature_cfg or {}
# resolve and update model pretrained config and model kwargs
pretrained_cfg = resolve_pretrained_cfg(
variant,
pretrained_cfg=pretrained_cfg,
pretrained_cfg_overlay=pretrained_cfg_overlay
)
# FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model
pretrained_cfg = pretrained_cfg.to_dict()
_update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter)
# Setup for feature extraction wrapper done at end of this fn
if kwargs.pop('features_only', False):
features = True
feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
if 'out_indices' in kwargs:
feature_cfg['out_indices'] = kwargs.pop('out_indices')
# Instantiate the model
if model_cfg is None:
model = model_cls(**kwargs)
else:
model = model_cls(cfg=model_cfg, **kwargs)
model.pretrained_cfg = pretrained_cfg
model.default_cfg = model.pretrained_cfg # alias for backwards compat
if pruned:
model = adapt_model_from_file(model, variant)
# For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
if pretrained:
if pretrained_cfg.get('custom_load', False):
load_custom_pretrained(
model,
pretrained_cfg=pretrained_cfg,
)
else:
load_pretrained(
model,
pretrained_cfg=pretrained_cfg,
num_classes=num_classes_pretrained,
in_chans=kwargs.get('in_chans', 3),
filter_fn=pretrained_filter_fn,
strict=pretrained_strict,
)
# Wrap the model in a feature extraction module if enabled
if features:
feature_cls = FeatureListNet
if 'feature_cls' in feature_cfg:
feature_cls = feature_cfg.pop('feature_cls')
if isinstance(feature_cls, str):
feature_cls = feature_cls.lower()
if 'hook' in feature_cls:
feature_cls = FeatureHookNet
elif feature_cls == 'fx':
feature_cls = FeatureGraphNet
else:
assert False, f'Unknown feature class {feature_cls}'
model = feature_cls(model, **feature_cfg)
model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg) # add back default_cfg
model.default_cfg = model.pretrained_cfg # alias for backwards compat
return model
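The `features_only` path at the end of `build_model_with_cfg` is what backs the feature-extraction API; a minimal sketch (the architecture name is an arbitrary example):

```python
import torch
import timm

# build_model_with_cfg pops `features_only` / `out_indices` and wraps the model in FeatureListNet
m = timm.create_model('resnet50', pretrained=False, features_only=True, out_indices=(1, 2, 3, 4))
feats = m(torch.randn(1, 3, 224, 224))   # list of feature maps, one per out index
print([f.shape for f in feats])
print(m.feature_info.channels(), m.feature_info.reduction())
```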

@@ -2,13 +2,12 @@
Hacked together by / Copyright 2019, Ross Wightman
"""
import torch
import torch.nn as nn
from torch.nn import functional as F
from timm.layers import create_conv2d, DropPath, make_divisible, create_act_layer, get_norm_act_layer
__all__ = [
'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual']

@@ -14,8 +14,8 @@ from functools import partial
import torch.nn as nn
from ._efficientnet_blocks import *
from timm.layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible
__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights",
'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT']

@@ -0,0 +1,103 @@
import os
from typing import Any, Dict, Optional, Union
from urllib.parse import urlsplit
from timm.layers import set_layer_config
from ._pretrained import PretrainedCfg, split_model_name_tag
from ._helpers import load_checkpoint
from ._hub import load_model_config_from_hf
from ._registry import is_model, model_entrypoint
__all__ = ['parse_model_name', 'safe_model_name', 'create_model']
def parse_model_name(model_name):
if model_name.startswith('hf_hub'):
# NOTE for backwards compat, deprecate hf_hub use
model_name = model_name.replace('hf_hub', 'hf-hub')
parsed = urlsplit(model_name)
assert parsed.scheme in ('', 'timm', 'hf-hub')
if parsed.scheme == 'hf-hub':
# FIXME may use fragment as revision, currently `@` in URI path
return parsed.scheme, parsed.path
else:
model_name = os.path.split(parsed.path)[-1]
return 'timm', model_name
def safe_model_name(model_name, remove_source=True):
# return a filename / path safe model name
def make_safe(name):
return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_')
if remove_source:
model_name = parse_model_name(model_name)[-1]
return make_safe(model_name)
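Worked examples of the two helpers above; the return values follow directly from the `urlsplit` / `make_safe` logic:

```python
from timm.models._factory import parse_model_name, safe_model_name

parse_model_name('hf-hub:BAAI/EVA')   # ('hf-hub', 'BAAI/EVA')
parse_model_name('resnet50')          # ('timm', 'resnet50')
safe_model_name('hf-hub:BAAI/EVA')    # 'BAAI_EVA' (source prefix removed, non-alphanumerics -> '_')
```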
def create_model(
model_name: str,
pretrained: bool = False,
pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
checkpoint_path: str = '',
scriptable: Optional[bool] = None,
exportable: Optional[bool] = None,
no_jit: Optional[bool] = None,
**kwargs,
):
"""Create a model
Lookup model's entrypoint function and pass relevant args to create a new model.
**kwargs will be passed through entrypoint fn to timm.models.build_model_with_cfg()
and then the model class __init__(). kwargs values set to None are pruned before passing.
Args:
model_name (str): name of model to instantiate
pretrained (bool): load pretrained ImageNet-1k weights if true
pretrained_cfg (Union[str, dict, PretrainedCfg]): pass in external pretrained_cfg for model
pretrained_cfg_overlay (dict): replace key-values in base pretrained_cfg with these
checkpoint_path (str): path of checkpoint to load _after_ the model is initialized
scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet)
exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet)
no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only)
Keyword Args:
drop_rate (float): dropout rate for training (default: 0.0)
global_pool (str): global pool type (default: 'avg')
**: other kwargs are consumed by builder or model __init__()
"""
# Parameters that aren't supported by all models or are intended to only override model defaults if set
# should default to None in command line args/cfg. Remove them if they are present and not set so that
# non-supporting models don't break and default args remain in effect.
kwargs = {k: v for k, v in kwargs.items() if v is not None}
model_source, model_name = parse_model_name(model_name)
if model_source == 'hf-hub':
assert not pretrained_cfg, 'pretrained_cfg should not be set when sourcing model from Hugging Face Hub.'
# For model names specified in the form `hf-hub:path/architecture_name@revision`,
# load model weights + pretrained_cfg from Hugging Face hub.
pretrained_cfg, model_name = load_model_config_from_hf(model_name)
else:
model_name, pretrained_tag = split_model_name_tag(model_name)
if not pretrained_cfg:
# a valid pretrained_cfg argument takes priority over tag in model name
pretrained_cfg = pretrained_tag
if not is_model(model_name):
raise RuntimeError('Unknown model (%s)' % model_name)
create_fn = model_entrypoint(model_name)
with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
model = create_fn(
pretrained=pretrained,
pretrained_cfg=pretrained_cfg,
pretrained_cfg_overlay=pretrained_cfg_overlay,
**kwargs,
)
if checkpoint_path:
load_checkpoint(model, checkpoint_path)
return model
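Given the tag and hub handling above, a model can be requested either as `architecture.pretrained_tag` or via an `hf-hub:` URI; a sketch (the hub repo id is a hypothetical placeholder):

```python
import timm

# architecture + pretrained tag, resolved through split_model_name_tag
m1 = timm.create_model('eva_giant_patch14_224.clip_ft_in1k', pretrained=True)

# config + weights pulled from the Hugging Face Hub (hypothetical repo id)
m2 = timm.create_model('hf-hub:some-org/some-timm-model', pretrained=True)
```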

@@ -0,0 +1,287 @@
""" PyTorch Feature Extraction Helpers
A collection of classes, functions, modules to help extract features from models
and provide a common interface for describing them.
The return_layers, module re-writing idea inspired by torchvision IntermediateLayerGetter
https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py
Hacked together by / Copyright 2020 Ross Wightman
"""
from collections import OrderedDict, defaultdict
from copy import deepcopy
from functools import partial
from typing import Dict, List, Tuple
import torch
import torch.nn as nn
__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']
class FeatureInfo:
def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
prev_reduction = 1
for fi in feature_info:
# sanity check the mandatory fields, there may be additional fields depending on the model
assert 'num_chs' in fi and fi['num_chs'] > 0
assert 'reduction' in fi and fi['reduction'] >= prev_reduction
prev_reduction = fi['reduction']
assert 'module' in fi
self.out_indices = out_indices
self.info = feature_info
def from_other(self, out_indices: Tuple[int]):
return FeatureInfo(deepcopy(self.info), out_indices)
def get(self, key, idx=None):
""" Get value by key at specified index (indices)
if idx == None, returns value for key at each output index
if idx is an integer, return value for that feature module index (ignoring output indices)
if idx is a list/tuple, return value for each module index (ignoring output indices)
"""
if idx is None:
return [self.info[i][key] for i in self.out_indices]
if isinstance(idx, (tuple, list)):
return [self.info[i][key] for i in idx]
else:
return self.info[idx][key]
def get_dicts(self, keys=None, idx=None):
""" return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)
"""
if idx is None:
if keys is None:
return [self.info[i] for i in self.out_indices]
else:
return [{k: self.info[i][k] for k in keys} for i in self.out_indices]
if isinstance(idx, (tuple, list)):
return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx]
else:
return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}
def channels(self, idx=None):
""" feature channels accessor
"""
return self.get('num_chs', idx)
def reduction(self, idx=None):
""" feature reduction (output stride) accessor
"""
return self.get('reduction', idx)
def module_name(self, idx=None):
""" feature module name accessor
"""
return self.get('module', idx)
def __getitem__(self, item):
return self.info[item]
def __len__(self):
return len(self.info)
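A short sketch of the `FeatureInfo` accessors above, assuming a `features_only` model whose `feature_info` is an instance of this class:

```python
import timm

m = timm.create_model('resnet50', features_only=True)
fi = m.feature_info                 # FeatureInfo limited to the selected out_indices
print(fi.channels())                # 'num_chs' at each output index
print(fi.reduction())               # 'reduction' (output stride) at each output index
print(fi.module_name())             # module names used for extraction
print(fi.get_dicts(keys=('num_chs', 'reduction')))
```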
class FeatureHooks:
""" Feature Hook Helper
This module helps with the setup and extraction of hooks for extracting features from
internal nodes in a model by node name. This works quite well in eager Python but needs
redesign for torchscript.
"""
def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'):
# setup feature hooks
modules = {k: v for k, v in named_modules}
for i, h in enumerate(hooks):
hook_name = h['module']
m = modules[hook_name]
hook_id = out_map[i] if out_map else hook_name
hook_fn = partial(self._collect_output_hook, hook_id)
hook_type = h.get('hook_type', default_hook_type)
if hook_type == 'forward_pre':
m.register_forward_pre_hook(hook_fn)
elif hook_type == 'forward':
m.register_forward_hook(hook_fn)
else:
assert False, "Unsupported hook type"
self._feature_outputs = defaultdict(OrderedDict)
def _collect_output_hook(self, hook_id, *args):
x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre
if isinstance(x, tuple):
x = x[0] # unwrap input tuple
self._feature_outputs[x.device][hook_id] = x
def get_output(self, device) -> Dict[str, torch.tensor]:
output = self._feature_outputs[device]
self._feature_outputs[device] = OrderedDict() # clear after reading
return output
def _module_list(module, flatten_sequential=False):
# a yield/iter would be better for this but wouldn't be compatible with torchscript
ml = []
for name, module in module.named_children():
if flatten_sequential and isinstance(module, nn.Sequential):
# first level of Sequential containers is flattened into containing model
for child_name, child_module in module.named_children():
combined = [name, child_name]
ml.append(('_'.join(combined), '.'.join(combined), child_module))
else:
ml.append((name, name, module))
return ml
def _get_feature_info(net, out_indices):
feature_info = getattr(net, 'feature_info')
if isinstance(feature_info, FeatureInfo):
return feature_info.from_other(out_indices)
elif isinstance(feature_info, (list, tuple)):
return FeatureInfo(net.feature_info, out_indices)
else:
assert False, "Provided feature_info is not valid"
def _get_return_layers(feature_info, out_map):
module_names = feature_info.module_name()
return_layers = {}
for i, name in enumerate(module_names):
return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]
return return_layers
class FeatureDictNet(nn.ModuleDict):
""" Feature extractor with OrderedDict return
Wrap a model and extract features as specified by the out indices; the network is
partially re-built from contained modules.
There is a strong assumption that the modules have been registered into the model in the same
order as they are used. There should be no reuse of the same nn.Module more than once, including
trivial modules like `self.relu = nn.ReLU`.
Only submodules that are directly assigned to the model class (`model.feature1`) or at most
one Sequential container deep (`model.features.1`, with flatten_sequential=True) can be captured.
All Sequential containers that are directly assigned to the original model will have their
modules assigned to this module with the name `model.features.1` being changed to `model.features_1`.
Arguments:
model (nn.Module): model from which we will extract the features
out_indices (tuple[int]): model output indices to extract features for
out_map (sequence): list or tuple specifying desired return id for each out index,
otherwise str(index) is used
feature_concat (bool): whether to concatenate intermediate features that are lists or tuples
vs select element [0]
flatten_sequential (bool): whether to flatten sequential modules assigned to model
"""
def __init__(
self, model,
out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
super(FeatureDictNet, self).__init__()
self.feature_info = _get_feature_info(model, out_indices)
self.concat = feature_concat
self.return_layers = {}
return_layers = _get_return_layers(self.feature_info, out_map)
modules = _module_list(model, flatten_sequential=flatten_sequential)
remaining = set(return_layers.keys())
layers = OrderedDict()
for new_name, old_name, module in modules:
layers[new_name] = module
if old_name in remaining:
# return id has to be consistently str type for torchscript
self.return_layers[new_name] = str(return_layers[old_name])
remaining.remove(old_name)
if not remaining:
break
assert not remaining and len(self.return_layers) == len(return_layers), \
f'Return layers ({remaining}) are not present in model'
self.update(layers)
def _collect(self, x) -> (Dict[str, torch.Tensor]):
out = OrderedDict()
for name, module in self.items():
x = module(x)
if name in self.return_layers:
out_id = self.return_layers[name]
if isinstance(x, (tuple, list)):
# If model tap is a tuple or list, concat or select first element
# FIXME this may need to be more generic / flexible for some nets
out[out_id] = torch.cat(x, 1) if self.concat else x[0]
else:
out[out_id] = x
return out
def forward(self, x) -> Dict[str, torch.Tensor]:
return self._collect(x)
class FeatureListNet(FeatureDictNet):
""" Feature extractor with list return
See docstring for FeatureDictNet above, this class exists only to appease Torchscript typing constraints.
In eager Python we could have returned List[Tensor] vs Dict[id, Tensor] based on a member bool.
"""
def __init__(
self, model,
out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
super(FeatureListNet, self).__init__(
model, out_indices=out_indices, out_map=out_map, feature_concat=feature_concat,
flatten_sequential=flatten_sequential)
def forward(self, x) -> (List[torch.Tensor]):
return list(self._collect(x).values())
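# Minimal usage sketch: wrap a toy backbone that exposes `feature_info` metadata and pull
# intermediate features as a dict (FeatureDictNet) or a list (FeatureListNet).
def _feature_net_sketch():
    class _TinyBackbone(nn.Module):
        def __init__(self):
            super().__init__()
            self.stem = nn.Conv2d(3, 8, 3, stride=2, padding=1)
            self.stage1 = nn.Conv2d(8, 16, 3, stride=2, padding=1)
            self.feature_info = [
                dict(num_chs=8, reduction=2, module='stem'),
                dict(num_chs=16, reduction=4, module='stage1'),
            ]

        def forward(self, x):
            return self.stage1(self.stem(x))

    x = torch.randn(1, 3, 32, 32)
    dict_net = FeatureDictNet(_TinyBackbone(), out_indices=(0, 1))
    out = dict_net(x)                    # keys are str(out_index): '0', '1'
    assert out['0'].shape[1] == 8 and out['1'].shape[1] == 16
    list_net = FeatureListNet(_TinyBackbone(), out_indices=(0, 1))
    feats = list_net(x)                  # [tensor (1, 8, 16, 16), tensor (1, 16, 8, 8)]
    assert len(feats) == 2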
class FeatureHookNet(nn.ModuleDict):
""" FeatureHookNet
Wrap a model and extract features specified by the out indices using forward/forward-pre hooks.
If `no_rewrite` is True, features are extracted via hooks without modifying the underlying
network in any way.
If `no_rewrite` is False, the model is re-written as in the FeatureList/FeatureDict case,
folding first-level modules (and second-level modules inside Sequential containers) into this one.
FIXME this does not currently work with Torchscript, see FeatureHooks class
"""
def __init__(
self, model,
out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, no_rewrite=False,
feature_concat=False, flatten_sequential=False, default_hook_type='forward'):
super(FeatureHookNet, self).__init__()
assert not torch.jit.is_scripting()
self.feature_info = _get_feature_info(model, out_indices)
self.out_as_dict = out_as_dict
layers = OrderedDict()
hooks = []
if no_rewrite:
assert not flatten_sequential
if hasattr(model, 'reset_classifier'): # make sure classifier is removed?
model.reset_classifier(0)
layers['body'] = model
hooks.extend(self.feature_info.get_dicts())
else:
modules = _module_list(model, flatten_sequential=flatten_sequential)
remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
for f in self.feature_info.get_dicts()}
for new_name, old_name, module in modules:
layers[new_name] = module
for fn, fm in module.named_modules(prefix=old_name):
if fn in remaining:
hooks.append(dict(module=fn, hook_type=remaining[fn]))
del remaining[fn]
if not remaining:
break
assert not remaining, f'Return layers ({remaining}) are not present in model'
self.update(layers)
self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)
def forward(self, x):
for name, module in self.items():
x = module(x)
out = self.hooks.get_output(x.device)
return out if self.out_as_dict else list(out.values())

@ -0,0 +1,110 @@
""" PyTorch FX Based Feature Extraction Helpers
Using https://pytorch.org/vision/stable/feature_extraction.html
"""
from typing import Callable, List, Dict, Union, Type
import torch
from torch import nn
from ._features import _get_feature_info
try:
from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor
has_fx_feature_extraction = True
except ImportError:
has_fx_feature_extraction = False
# Layers we want to treat as leaf modules
from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame
from timm.layers.non_local_attn import BilinearAttnTransform
from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
# BUT modules from timm.models should use the registration mechanism below
_leaf_modules = {
BilinearAttnTransform, # reason: flow control t <= 1
# Reason: get_same_padding has a max which raises a control flow error
Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame,
CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0])
}
try:
from timm.layers import InplaceAbn
_leaf_modules.add(InplaceAbn)
except ImportError:
pass
__all__ = ['register_notrace_module', 'register_notrace_function', 'create_feature_extractor',
'FeatureGraphNet', 'GraphExtractNet']
def register_notrace_module(module: Type[nn.Module]):
"""
Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
"""
_leaf_modules.add(module)
return module
# Functions we want to autowrap (treat them as leaves)
_autowrap_functions = set()
def register_notrace_function(func: Callable):
"""
Decorator for functions which ought not to be traced through
"""
_autowrap_functions.add(func)
return func
def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
return _create_feature_extractor(
model, return_nodes,
tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}
)
class FeatureGraphNet(nn.Module):
""" A FX Graph based feature extractor that works with the model feature_info metadata
"""
def __init__(self, model, out_indices, out_map=None):
super().__init__()
assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
self.feature_info = _get_feature_info(model, out_indices)
if out_map is not None:
assert len(out_map) == len(out_indices)
return_nodes = {
info['module']: out_map[i] if out_map is not None else info['module']
for i, info in enumerate(self.feature_info) if i in out_indices}
self.graph_module = create_feature_extractor(model, return_nodes)
def forward(self, x):
return list(self.graph_module(x).values())
class GraphExtractNet(nn.Module):
""" A standalone feature extraction wrapper that maps dict -> list or single tensor
NOTE:
* one can use feature_extractor directly if dictionary output is desired
* unlike FeatureGraphNet, this is intended to be used standalone and not with model feature_info
metadata for builtin feature extraction mode
* create_feature_extractor can be used directly if dictionary output is acceptable
Args:
model: model to extract features from
return_nodes: node names to return features from (dict or list)
squeeze_out: if only one output, and output in list format, flatten to single tensor
"""
def __init__(self, model, return_nodes: Union[Dict[str, str], List[str]], squeeze_out: bool = True):
super().__init__()
self.squeeze_out = squeeze_out
self.graph_module = create_feature_extractor(model, return_nodes)
def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]:
out = list(self.graph_module(x).values())
if self.squeeze_out and len(out) == 1:
return out[0]
return out
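# Hedged sketch of the FX-based extractors: requires torchvision's FX feature extraction (0.11+).
# Node names are discovered with torchvision's get_graph_node_names rather than hard-coded, since
# traced node names depend on the model; the tiny Sequential below is purely illustrative.
def _graph_extract_sketch():
    from torchvision.models.feature_extraction import get_graph_node_names
    net = nn.Sequential(
        nn.Conv2d(3, 8, 3, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(8, 16, 3, stride=2, padding=1),
    )
    train_nodes, eval_nodes = get_graph_node_names(net)    # traceable node names for this model
    extractor = GraphExtractNet(net, return_nodes=eval_nodes[-2:], squeeze_out=False)
    feats = extractor(torch.randn(1, 3, 32, 32))           # list of tensors, one per requested node
    return eval_nodes, feats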

@ -0,0 +1,115 @@
""" Model creation / weight loading / state_dict helpers
Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
import os
from collections import OrderedDict
import torch
import timm.models._builder
_logger = logging.getLogger(__name__)
__all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_checkpoint', 'resume_checkpoint']
def clean_state_dict(state_dict):
# 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training
cleaned_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] if k.startswith('module.') else k
cleaned_state_dict[name] = v
return cleaned_state_dict
def load_state_dict(checkpoint_path, use_ema=True):
if checkpoint_path and os.path.isfile(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location='cpu')
state_dict_key = ''
if isinstance(checkpoint, dict):
if use_ema and checkpoint.get('state_dict_ema', None) is not None:
state_dict_key = 'state_dict_ema'
elif use_ema and checkpoint.get('model_ema', None) is not None:
state_dict_key = 'model_ema'
elif 'state_dict' in checkpoint:
state_dict_key = 'state_dict'
elif 'model' in checkpoint:
state_dict_key = 'model'
state_dict = clean_state_dict(checkpoint[state_dict_key] if state_dict_key else checkpoint)
_logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
return state_dict
else:
_logger.error("No checkpoint found at '{}'".format(checkpoint_path))
raise FileNotFoundError()
def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True, remap=False):
if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
# numpy checkpoint, try to load via model specific load_pretrained fn
if hasattr(model, 'load_pretrained'):
model.load_pretrained(checkpoint_path)
else:
raise NotImplementedError('Model cannot load numpy checkpoint')
return
state_dict = load_state_dict(checkpoint_path, use_ema)
if remap:
state_dict = remap_checkpoint(model, state_dict)
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
return incompatible_keys
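# Usage sketch: restore fine-tuned weights into an already-constructed model. The checkpoint path
# below is a placeholder for a file written by timm's train script.
def _load_checkpoint_sketch(model):
    incompatible = load_checkpoint(model, './output/train/model_best.pth.tar', use_ema=True, strict=False)
    if incompatible is not None and incompatible.missing_keys:
        _logger.warning('Missing keys when loading checkpoint: %s', incompatible.missing_keys)
    return model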
def remap_checkpoint(model, state_dict, allow_reshape=True):
""" remap checkpoint by iterating over state dicts in order (ignoring original keys).
This assumes models (and originating state dict) were created with params registered in same order.
"""
out_dict = {}
for (ka, va), (kb, vb) in zip(model.state_dict().items(), state_dict.items()):
assert va.numel() == vb.numel(), f'Tensor size mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
if va.shape != vb.shape:
if allow_reshape:
vb = vb.reshape(va.shape)
else:
assert False, f'Tensor shape mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
out_dict[ka] = vb
return out_dict
def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
resume_epoch = None
if os.path.isfile(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location='cpu')
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
if log_info:
_logger.info('Restoring model state from checkpoint...')
state_dict = clean_state_dict(checkpoint['state_dict'])
model.load_state_dict(state_dict)
if optimizer is not None and 'optimizer' in checkpoint:
if log_info:
_logger.info('Restoring optimizer state from checkpoint...')
optimizer.load_state_dict(checkpoint['optimizer'])
if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
if log_info:
_logger.info('Restoring AMP loss scaler state from checkpoint...')
loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])
if 'epoch' in checkpoint:
resume_epoch = checkpoint['epoch']
if 'version' in checkpoint and checkpoint['version'] > 1:
resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save
if log_info:
_logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
else:
model.load_state_dict(checkpoint)
if log_info:
_logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
return resume_epoch
else:
_logger.error("No checkpoint found at '{}'".format(checkpoint_path))
raise FileNotFoundError()
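# Usage sketch for resuming training: model/optimizer state are restored in place and the return
# value is the epoch to continue from (None if the checkpoint carries no epoch). The default
# resume path is a placeholder.
def _resume_sketch(model, optimizer, resume_path='./output/train/last.pth.tar'):
    resume_epoch = resume_checkpoint(model, resume_path, optimizer=optimizer, log_info=True)
    start_epoch = resume_epoch if resume_epoch is not None else 0
    return start_epoch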

@ -0,0 +1,220 @@
import json
import logging
import os
from functools import partial
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional, Union
import torch
from torch.hub import HASH_REGEX, download_url_to_file, urlparse
try:
from torch.hub import get_dir
except ImportError:
from torch.hub import _get_torch_home as get_dir
from timm import __version__
from timm.models._pretrained import filter_pretrained_cfg
try:
from huggingface_hub import (
create_repo, get_hf_file_metadata,
hf_hub_download, hf_hub_url,
repo_type_and_id_from_hf_id, upload_folder)
from huggingface_hub.utils import EntryNotFoundError
hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
_has_hf_hub = True
except ImportError:
hf_hub_download = None
_has_hf_hub = False
_logger = logging.getLogger(__name__)
__all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf',
'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub']
def get_cache_dir(child_dir=''):
"""
Returns the location of the directory where models are cached (and creates it if necessary).
"""
# Issue warning to move data if old env is set
if os.getenv('TORCH_MODEL_ZOO'):
_logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
hub_dir = get_dir()
child_dir = () if not child_dir else (child_dir,)
model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir)
os.makedirs(model_dir, exist_ok=True)
return model_dir
def download_cached_file(url, check_hash=True, progress=False):
if isinstance(url, (list, tuple)):
url, filename = url
else:
parts = urlparse(url)
filename = os.path.basename(parts.path)
cached_file = os.path.join(get_cache_dir(), filename)
if not os.path.exists(cached_file):
_logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
hash_prefix = None
if check_hash:
r = HASH_REGEX.search(filename) # r is Optional[Match[str]]
hash_prefix = r.group(1) if r else None
download_url_to_file(url, cached_file, hash_prefix, progress=progress)
return cached_file
def has_hf_hub(necessary=False):
if not _has_hf_hub and necessary:
# if no HF Hub module installed, and it is necessary to continue, raise error
raise RuntimeError(
'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
return _has_hf_hub
def hf_split(hf_id):
# FIXME I may change '@' -> '#' so the revision is parsed as a fragment in a URI model name scheme
rev_split = hf_id.split('@')
assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.'
hf_model_id = rev_split[0]
hf_revision = rev_split[-1] if len(rev_split) > 1 else None
return hf_model_id, hf_revision
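# Behaviour sketch for hf_split: an optional '@revision' suffix selects a repo revision.
# The repo ids below are illustrative.
def _hf_split_sketch():
    assert hf_split('timm/resnet50.a1_in1k') == ('timm/resnet50.a1_in1k', None)
    assert hf_split('timm/resnet50.a1_in1k@main') == ('timm/resnet50.a1_in1k', 'main')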
def load_cfg_from_json(json_file: Union[str, os.PathLike]):
with open(json_file, "r", encoding="utf-8") as reader:
text = reader.read()
return json.loads(text)
def _download_from_hf(model_id: str, filename: str):
hf_model_id, hf_revision = hf_split(model_id)
return hf_hub_download(hf_model_id, filename, revision=hf_revision)
def load_model_config_from_hf(model_id: str):
assert has_hf_hub(True)
cached_file = _download_from_hf(model_id, 'config.json')
hf_config = load_cfg_from_json(cached_file)
if 'pretrained_cfg' not in hf_config:
# old form, pull pretrain_cfg out of the base dict
pretrained_cfg = hf_config
hf_config = {}
hf_config['architecture'] = pretrained_cfg.pop('architecture')
hf_config['num_features'] = pretrained_cfg.pop('num_features', None)
if 'labels' in pretrained_cfg:
hf_config['label_name'] = pretrained_cfg.pop('labels')
hf_config['pretrained_cfg'] = pretrained_cfg
# NOTE currently discarding parent config as only arch name and pretrained_cfg used in timm right now
pretrained_cfg = hf_config['pretrained_cfg']
pretrained_cfg['hf_hub_id'] = model_id # insert hf_hub id for pretrained weight load during model creation
pretrained_cfg['source'] = 'hf-hub'
if 'num_classes' in hf_config:
# model should be created with parent num_classes if they exist
pretrained_cfg['num_classes'] = hf_config['num_classes']
model_name = hf_config['architecture']
return pretrained_cfg, model_name
def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'):
assert has_hf_hub(True)
cached_file = _download_from_hf(model_id, filename)
state_dict = torch.load(cached_file, map_location='cpu')
return state_dict
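# Hedged sketch: pull the pretrained_cfg and raw state_dict for a Hub-hosted timm model.
# The repo id is illustrative and requires network access plus the huggingface_hub package.
def _hf_hub_load_sketch(model_id='timm/resnet50.a1_in1k'):
    pretrained_cfg, model_name = load_model_config_from_hf(model_id)
    state_dict = load_state_dict_from_hf(model_id)
    return model_name, pretrained_cfg, state_dict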
def save_for_hf(model, save_directory, model_config=None):
assert has_hf_hub(True)
model_config = model_config or {}
save_directory = Path(save_directory)
save_directory.mkdir(exist_ok=True, parents=True)
weights_path = save_directory / 'pytorch_model.bin'
torch.save(model.state_dict(), weights_path)
config_path = save_directory / 'config.json'
hf_config = {}
pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
# set some values at root config level
hf_config['architecture'] = pretrained_cfg.pop('architecture')
hf_config['num_classes'] = model_config.get('num_classes', model.num_classes)
hf_config['num_features'] = model_config.get('num_features', model.num_features)
hf_config['global_pool'] = model_config.get('global_pool', getattr(model, 'global_pool', None))
if 'label' in model_config:
_logger.warning(
"'label' as a config field for timm models is deprecated. Please use 'label_name' and 'display_name'. "
"Using provided 'label' field as 'label_name'.")
model_config['label_name'] = model_config.pop('label')
label_name = model_config.pop('label_name', None)
if label_name:
assert isinstance(label_name, (dict, list, tuple))
# map label id (classifier index) -> unique label name (ie synset for ImageNet, MID for OpenImages)
# can be a dict id: name if there are id gaps, or tuple/list if no gaps.
hf_config['label_name'] = label_name
display_name = model_config.pop('display_name', None)
if display_name:
assert isinstance(display_name, dict)
# map label_name -> user interface display name
hf_config['display_name'] = display_name
hf_config['pretrained_cfg'] = pretrained_cfg
hf_config.update(model_config)
with config_path.open('w') as f:
json.dump(hf_config, f, indent=2)
def push_to_hf_hub(
model,
repo_id: str,
commit_message: str = 'Add model',
token: Optional[str] = None,
revision: Optional[str] = None,
private: bool = False,
create_pr: bool = False,
model_config: Optional[dict] = None,
):
# Create repo if it doesn't exist yet
repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
# Infer complete repo_id from repo_url
# Can be different from the input `repo_id` if repo_owner was implicit
_, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
repo_id = f"{repo_owner}/{repo_name}"
# Check if README file already exist in repo
try:
get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
has_readme = True
except EntryNotFoundError:
has_readme = False
# Dump model and push to Hub
with TemporaryDirectory() as tmpdir:
# Save model weights and config.
save_for_hf(model, tmpdir, model_config=model_config)
# Add readme if it does not exist
if not has_readme:
model_name = repo_id.split('/')[-1]
readme_path = Path(tmpdir) / "README.md"
readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {model_name}'
readme_path.write_text(readme_text)
# Upload model and return
return upload_folder(
repo_id=repo_id,
folder_path=tmpdir,
revision=revision,
create_pr=create_pr,
commit_message=commit_message,
)
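# Hedged sketch: export a model to the Hub. The repo id is a placeholder and a valid write token
# is assumed (e.g. via `huggingface-cli login`); model_config keys mirror save_for_hf above.
def _push_sketch():
    import timm
    model = timm.create_model('resnet18', pretrained=True, num_classes=2)
    push_to_hf_hub(
        model,
        repo_id='my-user/resnet18-binary',        # placeholder repo id
        commit_message='Add fine-tuned resnet18',
        model_config={'label_name': ['negative', 'positive']},
    )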

@ -0,0 +1,258 @@
import collections.abc
import math
import re
from collections import defaultdict
from itertools import chain
from typing import Callable, Union, Dict
import torch
from torch import nn as nn
from torch.utils.checkpoint import checkpoint
__all__ = ['model_parameters', 'named_apply', 'named_modules', 'named_modules_with_params', 'adapt_input_conv',
'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq']
def model_parameters(model, exclude_head=False):
if exclude_head:
# FIXME this is a bit of a quick and dirty hack to skip classifier head params based on ordering
return [p for p in model.parameters()][:-2]
else:
return model.parameters()
def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module:
if not depth_first and include_root:
fn(module=module, name=name)
for child_name, child_module in module.named_children():
child_name = '.'.join((name, child_name)) if name else child_name
named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
if depth_first and include_root:
fn(module=module, name=name)
return module
def named_modules(module: nn.Module, name='', depth_first=True, include_root=False):
if not depth_first and include_root:
yield name, module
for child_name, child_module in module.named_children():
child_name = '.'.join((name, child_name)) if name else child_name
yield from named_modules(
module=child_module, name=child_name, depth_first=depth_first, include_root=True)
if depth_first and include_root:
yield name, module
def named_modules_with_params(module: nn.Module, name='', depth_first=True, include_root=False):
if module._parameters and not depth_first and include_root:
yield name, module
for child_name, child_module in module.named_children():
child_name = '.'.join((name, child_name)) if name else child_name
yield from named_modules_with_params(
module=child_module, name=child_name, depth_first=depth_first, include_root=True)
if module._parameters and depth_first and include_root:
yield name, module
MATCH_PREV_GROUP = (99999,)
def group_with_matcher(
named_objects,
group_matcher: Union[Dict, Callable],
output_values: bool = False,
reverse: bool = False
):
if isinstance(group_matcher, dict):
# dictionary matcher contains a dict of raw-string regex expr that must be compiled
compiled = []
for group_ordinal, (group_name, mspec) in enumerate(group_matcher.items()):
if mspec is None:
continue
# map all matching specifications into 3-tuple (compiled re, prefix, suffix)
if isinstance(mspec, (tuple, list)):
# multi-entry match specifications require each sub-spec to be a 2-tuple (re, suffix)
for sspec in mspec:
compiled += [(re.compile(sspec[0]), (group_ordinal,), sspec[1])]
else:
compiled += [(re.compile(mspec), (group_ordinal,), None)]
group_matcher = compiled
def _get_grouping(name):
if isinstance(group_matcher, (list, tuple)):
for match_fn, prefix, suffix in group_matcher:
r = match_fn.match(name)
if r:
parts = (prefix, r.groups(), suffix)
# map all tuple elem to int for numeric sort, filter out None entries
return tuple(map(float, chain.from_iterable(filter(None, parts))))
return float('inf'), # un-matched layers (neck, head) mapped to largest ordinal
else:
ord = group_matcher(name)
if not isinstance(ord, collections.abc.Iterable):
return ord,
return tuple(ord)
# map layers into groups via ordinals (ints or tuples of ints) from matcher
grouping = defaultdict(list)
for k, v in named_objects:
grouping[_get_grouping(k)].append(v if output_values else k)
# remap to integers
layer_id_to_param = defaultdict(list)
lid = -1
for k in sorted(filter(lambda x: x is not None, grouping.keys())):
if lid < 0 or k[-1] != MATCH_PREV_GROUP[0]:
lid += 1
layer_id_to_param[lid].extend(grouping[k])
if reverse:
assert not output_values, "reverse mapping only sensible for name output"
# output reverse mapping
param_to_layer_id = {}
for lid, lm in layer_id_to_param.items():
for n in lm:
param_to_layer_id[n] = lid
return param_to_layer_id
return layer_id_to_param
def group_parameters(
module: nn.Module,
group_matcher,
output_values=False,
reverse=False,
):
return group_with_matcher(
module.named_parameters(), group_matcher, output_values=output_values, reverse=reverse)
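# Usage sketch: a group_matcher of raw regex strings maps parameter names to layer-group ids,
# similar to what a model's group_matcher() provides for layer-wise LR decay. The toy model and
# matcher below are assumptions for illustration only.
def _group_parameters_sketch():
    class _Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.stem = nn.Conv2d(3, 8, 3)
            self.layer1 = nn.Sequential(nn.Conv2d(8, 8, 3))
            self.layer2 = nn.Sequential(nn.Conv2d(8, 8, 3))
            self.head = nn.Linear(8, 2)

    matcher = dict(stem=r'^stem', blocks=r'^layer(\d+)')
    groups = group_parameters(_Toy(), matcher)
    # -> {0: stem params, 1: layer1 params, 2: layer2 params, 3: head params (un-matched, last group)}
    assert sorted(groups.keys()) == [0, 1, 2, 3]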
def group_modules(
module: nn.Module,
group_matcher,
output_values=False,
reverse=False,
):
return group_with_matcher(
named_modules_with_params(module), group_matcher, output_values=output_values, reverse=reverse)
def flatten_modules(named_modules, depth=1, prefix='', module_types='sequential'):
prefix_is_tuple = isinstance(prefix, tuple)
if isinstance(module_types, str):
if module_types == 'container':
module_types = (nn.Sequential, nn.ModuleList, nn.ModuleDict)
else:
module_types = (nn.Sequential,)
for name, module in named_modules:
if depth and isinstance(module, module_types):
yield from flatten_modules(
module.named_children(),
depth - 1,
prefix=(name,) if prefix_is_tuple else name,
module_types=module_types,
)
else:
if prefix_is_tuple:
name = prefix + (name,)
yield name, module
else:
if prefix:
name = '.'.join([prefix, name])
yield name, module
def checkpoint_seq(
functions,
x,
every=1,
flatten=False,
skip_last=False,
preserve_rng_state=True
):
r"""A helper function for checkpointing sequential models.
Sequential models execute a list of modules/functions in order
(sequentially). Therefore, we can divide such a sequence into segments
and checkpoint each segment. All segments except the last run in :func:`torch.no_grad`
manner, i.e., not storing the intermediate activations. The inputs of each
checkpointed segment will be saved for re-running the segment in the backward pass.
See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.
.. warning::
Checkpointing currently only supports :func:`torch.autograd.backward`
and only if its `inputs` argument is not passed. :func:`torch.autograd.grad`
is not supported.
.. warning:
At least one of the inputs needs to have :code:`requires_grad=True` if
grads are needed for model inputs, otherwise the checkpointed part of the
model won't have gradients.
Args:
functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially.
x: A Tensor that is input to :attr:`functions`
every: checkpoint every-n functions (default: 1)
flatten (bool): flatten nn.Sequential of nn.Sequentials
skip_last (bool): skip checkpointing the last function in the sequence if True
preserve_rng_state (bool, optional, default=True): Stash and restore the RNG state
during each checkpointed segment (set to False to omit).
Returns:
Output of running :attr:`functions` sequentially on :attr:`x`
Example:
>>> model = nn.Sequential(...)
>>> input_var = checkpoint_seq(model, input_var, every=2)
"""
def run_function(start, end, functions):
def forward(_x):
for j in range(start, end + 1):
_x = functions[j](_x)
return _x
return forward
if isinstance(functions, torch.nn.Sequential):
functions = functions.children()
if flatten:
functions = chain.from_iterable(functions)
if not isinstance(functions, (tuple, list)):
functions = tuple(functions)
num_checkpointed = len(functions)
if skip_last:
num_checkpointed -= 1
end = -1
for start in range(0, num_checkpointed, every):
end = min(start + every - 1, num_checkpointed - 1)
x = checkpoint(run_function(start, end, functions), x, preserve_rng_state=preserve_rng_state)
if skip_last:
return run_function(end + 1, len(functions) - 1, functions)(x)
return x
def adapt_input_conv(in_chans, conv_weight):
conv_type = conv_weight.dtype
conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU
O, I, J, K = conv_weight.shape
if in_chans == 1:
if I > 3:
assert conv_weight.shape[1] % 3 == 0
# For models with space2depth stems
conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)
conv_weight = conv_weight.sum(dim=2, keepdim=False)
else:
conv_weight = conv_weight.sum(dim=1, keepdim=True)
elif in_chans != 3:
if I != 3:
raise NotImplementedError('Weight format not supported by conversion.')
else:
# NOTE this strategy should be better than random init, but there could be other combinations of
# the original RGB input layer weights that'd work better for specific cases.
repeat = int(math.ceil(in_chans / 3))
conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
conv_weight *= (3 / float(in_chans))
conv_weight = conv_weight.to(conv_type)
return conv_weight
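# Usage sketch: adapt a pretrained 3-channel stem conv weight to a different number of input
# channels (sum to 1 channel for grayscale, repeat and rescale for more than 3 channels).
def _adapt_input_conv_sketch():
    w_rgb = torch.randn(64, 3, 7, 7)
    w_gray = adapt_input_conv(1, w_rgb)
    w_multi = adapt_input_conv(6, w_rgb)
    assert w_gray.shape == (64, 1, 7, 7)
    assert w_multi.shape == (64, 6, 7, 7)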

@ -4,6 +4,9 @@ from dataclasses import dataclass, field, replace, asdict
from typing import Any, Deque, Dict, Tuple, Optional, Union
+__all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg', 'split_model_name_tag', 'generate_default_cfgs']
@dataclass
class PretrainedCfg:
"""

@ -0,0 +1,113 @@
import os
from copy import deepcopy
from torch import nn as nn
from timm.layers import Conv2dSame, BatchNormAct2d, Linear
__all__ = ['extract_layer', 'set_layer', 'adapt_model_from_string', 'adapt_model_from_file']
def extract_layer(model, layer):
layer = layer.split('.')
module = model
if hasattr(model, 'module') and layer[0] != 'module':
module = model.module
if not hasattr(model, 'module') and layer[0] == 'module':
layer = layer[1:]
for l in layer:
if hasattr(module, l):
if not l.isdigit():
module = getattr(module, l)
else:
module = module[int(l)]
else:
return module
return module
def set_layer(model, layer, val):
layer = layer.split('.')
module = model
if hasattr(model, 'module') and layer[0] != 'module':
module = model.module
lst_index = 0
module2 = module
for l in layer:
if hasattr(module2, l):
if not l.isdigit():
module2 = getattr(module2, l)
else:
module2 = module2[int(l)]
lst_index += 1
lst_index -= 1
for l in layer[:lst_index]:
if not l.isdigit():
module = getattr(module, l)
else:
module = module[int(l)]
l = layer[lst_index]
setattr(module, l, val)
def adapt_model_from_string(parent_module, model_string):
separator = '***'
state_dict = {}
lst_shape = model_string.split(separator)
for k in lst_shape:
k = k.split(':')
key = k[0]
shape = k[1][1:-1].split(',')
if shape[0] != '':
state_dict[key] = [int(i) for i in shape]
new_module = deepcopy(parent_module)
for n, m in parent_module.named_modules():
old_module = extract_layer(parent_module, n)
if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame):
if isinstance(old_module, Conv2dSame):
conv = Conv2dSame
else:
conv = nn.Conv2d
s = state_dict[n + '.weight']
in_channels = s[1]
out_channels = s[0]
g = 1
if old_module.groups > 1:
in_channels = out_channels
g = in_channels
new_conv = conv(
in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size,
bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
groups=g, stride=old_module.stride)
set_layer(new_module, n, new_conv)
elif isinstance(old_module, BatchNormAct2d):
new_bn = BatchNormAct2d(
state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
affine=old_module.affine, track_running_stats=True)
new_bn.drop = old_module.drop
new_bn.act = old_module.act
set_layer(new_module, n, new_bn)
elif isinstance(old_module, nn.BatchNorm2d):
new_bn = nn.BatchNorm2d(
num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
affine=old_module.affine, track_running_stats=True)
set_layer(new_module, n, new_bn)
elif isinstance(old_module, nn.Linear):
# FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
num_features = state_dict[n + '.weight'][1]
new_fc = Linear(
in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
set_layer(new_module, n, new_fc)
if hasattr(new_module, 'num_features'):
new_module.num_features = num_features
new_module.eval()
parent_module.eval()
return new_module
def adapt_model_from_file(parent_module, model_variant):
adapt_file = os.path.join(os.path.dirname(__file__), '_pruned', model_variant + '.txt')
with open(adapt_file, 'r') as f:
return adapt_model_from_string(parent_module, f.read().strip())
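# Usage sketch for the pruning description format: '***'-separated 'name:[shape]' entries give the
# target weight shapes, and matching conv/linear layers are rebuilt to those sizes. The toy model
# and shapes below are assumptions for illustration.
def _adapt_model_sketch():
    class _Tiny(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 32, 3)
            self.fc = nn.Linear(32, 2)

    pruned = adapt_model_from_string(_Tiny(), 'conv.weight:[16, 3, 3, 3]***fc.weight:[2, 16]')
    assert pruned.conv.out_channels == 16 and pruned.fc.in_features == 16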

@ -0,0 +1,212 @@
""" Model Registry
Hacked together by / Copyright 2020 Ross Wightman
"""
import fnmatch
import re
import sys
from collections import defaultdict, deque
from copy import deepcopy
from typing import List, Optional, Union, Tuple
from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
__all__ = [
'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name']
_module_to_models = defaultdict(set) # dict of sets to check membership of model in module
_model_to_module = {} # mapping of model names to module names
_model_entrypoints = {} # mapping of model names to architecture entrypoint fns
_model_has_pretrained = set() # set of model names that have pretrained weight url present
_model_default_cfgs = dict() # central repo for model arch -> default cfg objects
_model_pretrained_cfgs = dict() # central repo for model arch + tag -> pretrained cfgs
_model_with_tags = defaultdict(list) # shortcut to map each model arch to all model + tag names
def get_arch_name(model_name: str) -> Tuple[str, Optional[str]]:
return split_model_name_tag(model_name)[0]
def register_model(fn):
# lookup containing module
mod = sys.modules[fn.__module__]
module_name_split = fn.__module__.split('.')
module_name = module_name_split[-1] if len(module_name_split) else ''
# add model to __all__ in module
model_name = fn.__name__
if hasattr(mod, '__all__'):
mod.__all__.append(model_name)
else:
mod.__all__ = [model_name]
# add entries to registry dict/sets
_model_entrypoints[model_name] = fn
_model_to_module[model_name] = module_name
_module_to_models[module_name].add(model_name)
if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
# this will catch all models that have entrypoint matching cfg key, but miss any aliasing
# entrypoints or non-matching combos
cfg = mod.default_cfgs[model_name]
if not isinstance(cfg, DefaultCfg):
# not the new-style DefaultCfg dataclass (one arch, multiple pretrained-tag entries), so
# assume an old-style per-arch cfg dict and wrap it in a single-entry DefaultCfg
assert isinstance(cfg, dict)
cfg = PretrainedCfg(**cfg)
cfg = DefaultCfg(tags=deque(['']), cfgs={'': cfg})
for tag_idx, tag in enumerate(cfg.tags):
is_default = tag_idx == 0
pretrained_cfg = cfg.cfgs[tag]
if is_default:
_model_pretrained_cfgs[model_name] = pretrained_cfg
if pretrained_cfg.has_weights:
# add tagless entry if it's default and has weights
_model_has_pretrained.add(model_name)
if tag:
model_name_tag = '.'.join([model_name, tag])
_model_pretrained_cfgs[model_name_tag] = pretrained_cfg
if pretrained_cfg.has_weights:
# add model w/ tag if tag is valid
_model_has_pretrained.add(model_name_tag)
_model_with_tags[model_name].append(model_name_tag)
else:
_model_with_tags[model_name].append(model_name) # has empty tag (to slowly remove these instances)
_model_default_cfgs[model_name] = cfg
return fn
def _natural_key(string_):
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
def list_models(
filter: Union[str, List[str]] = '',
module: str = '',
pretrained=False,
exclude_filters: str = '',
name_matches_cfg: bool = False,
include_tags: Optional[bool] = None,
):
""" Return list of available model names, sorted alphabetically
Args:
filter (str) - Wildcard filter string that works with fnmatch
module (str) - Limit model selection to a specific submodule (ie 'vision_transformer')
pretrained (bool) - Include only models with valid pretrained weights if True
exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)
include_tags (Optional[bool]) - Include pretrained tags in model names (model.tag). If None,
defaults to True when pretrained=True, else False (default: None)
Example:
list_models('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
list_models('*resnext*', 'resnet') -- returns all models with 'resnext' in the 'resnet' module
"""
if include_tags is None:
# FIXME should this be default behaviour? or default to include_tags=True?
include_tags = pretrained
if module:
all_models = list(_module_to_models[module])
else:
all_models = _model_entrypoints.keys()
if include_tags:
# expand model names to include names w/ pretrained tags
models_with_tags = []
for m in all_models:
models_with_tags.extend(_model_with_tags[m])
all_models = models_with_tags
if filter:
models = []
include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
for f in include_filters:
include_models = fnmatch.filter(all_models, f) # include these models
if len(include_models):
models = set(models).union(include_models)
else:
models = all_models
if exclude_filters:
if not isinstance(exclude_filters, (tuple, list)):
exclude_filters = [exclude_filters]
for xf in exclude_filters:
exclude_models = fnmatch.filter(models, xf) # exclude these models
if len(exclude_models):
models = set(models).difference(exclude_models)
if pretrained:
models = _model_has_pretrained.intersection(models)
if name_matches_cfg:
models = set(_model_pretrained_cfgs).intersection(models)
return list(sorted(models, key=_natural_key))
def list_pretrained(
filter: Union[str, List[str]] = '',
exclude_filters: str = '',
):
return list_models(
filter=filter,
pretrained=True,
exclude_filters=exclude_filters,
include_tags=True,
)
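# Usage sketch for the registry queries above; the exact results depend on the installed timm version.
def _registry_query_sketch():
    convnexts = list_models('convnext*')                      # all ConvNeXt architecture names
    with_weights = list_models('resnet50*', pretrained=True)  # includes 'arch.tag' pretrained names
    eva_pretrained = list_pretrained('eva_*')                 # pretrained 'arch.tag' names only
    return convnexts, with_weights, eva_pretrained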
def is_model(model_name):
""" Check if a model name exists
"""
arch_name = get_arch_name(model_name)
return arch_name in _model_entrypoints
def model_entrypoint(model_name, module_filter: Optional[str] = None):
"""Fetch a model entrypoint for specified model name
"""
arch_name = get_arch_name(model_name)
if module_filter and arch_name not in _module_to_models.get(module_filter, {}):
raise RuntimeError(f'Model ({model_name}) not found in module {module_filter}.')
return _model_entrypoints[arch_name]
def list_modules():
""" Return list of module names that contain models / model entrypoints
"""
modules = _module_to_models.keys()
return list(sorted(modules))
def is_model_in_modules(model_name, module_names):
"""Check if a model exists within a subset of modules
Args:
model_name (str) - name of model to check
module_names (tuple, list, set) - names of modules to search in
"""
arch_name = get_arch_name(model_name)
assert isinstance(module_names, (tuple, list, set))
return any(arch_name in _module_to_models[n] for n in module_names)
def is_model_pretrained(model_name):
return model_name in _model_has_pretrained
def get_pretrained_cfg(model_name):
if model_name in _model_pretrained_cfgs:
return deepcopy(_model_pretrained_cfgs[model_name])
raise RuntimeError(f'No pretrained config exists for model {model_name}.')
def get_pretrained_cfg_value(model_name, cfg_key):
""" Get a specific model default_cfg value by key. None if key doesn't exist.
"""
if model_name in _model_pretrained_cfgs:
return getattr(_model_pretrained_cfgs[model_name], cfg_key, None)
raise RuntimeError(f'No pretrained config exists for model {model_name}.')

@ -61,12 +61,14 @@ import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
-from .helpers import build_model_with_cfg
-from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_
-from .pretrained import generate_default_cfgs
-from .registry import register_model
+from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_
+from ._builder import build_model_with_cfg
+from ._pretrained import generate_default_cfgs
+from ._registry import register_model
from .vision_transformer import checkpoint_filter_fn
+__all__ = ['Beit']
def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor:
num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3

@ -13,9 +13,9 @@ Consider all of the models definitions here as experimental WIP and likely to ch
Hacked together by / copyright Ross Wightman, 2021.
"""
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from ._builder import build_model_with_cfg
+from ._registry import register_model
from .byobnet import ByoBlockCfg, ByoModelCfg, ByobNet, interleave_blocks
-from .helpers import build_model_with_cfg
-from .registry import register_model
__all__ = []

@ -26,18 +26,18 @@ Hacked together by / copyright Ross Wightman, 2021.
""" """
import math import math
from dataclasses import dataclass, field, replace from dataclasses import dataclass, field, replace
from typing import Tuple, List, Dict, Optional, Union, Any, Callable, Sequence
from functools import partial from functools import partial
from typing import Tuple, List, Dict, Optional, Union, Any, Callable, Sequence
import torch import torch
import torch.nn as nn import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, named_apply, checkpoint_seq from timm.layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
from .layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \ create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0a
create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0, EvoNorm2dS0a,\ from ._builder import build_model_with_cfg
EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a, FilterResponseNormAct2d, FilterResponseNormTlu2d from ._manipulate import named_apply, checkpoint_seq
from .registry import register_model from ._registry import register_model
__all__ = ['ByobNet', 'ByoModelCfg', 'ByoBlockCfg', 'create_byob_stem', 'create_block'] __all__ = ['ByobNet', 'ByoModelCfg', 'ByoBlockCfg', 'create_byob_stem', 'create_block']

@ -8,17 +8,16 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
""" """
# Copyright (c) 2015-present, Facebook, Inc. # Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved. # All rights reserved.
from copy import deepcopy
from functools import partial from functools import partial
import torch import torch
import torch.nn as nn import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, checkpoint_seq from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_ from ._builder import build_model_with_cfg
from .registry import register_model from ._manipulate import checkpoint_seq
from ._registry import register_model
__all__ = ['Cait', 'ClassAttn', 'LayerScaleBlockClassAttn', 'LayerScaleBlock', 'TalkingHeadAttn'] __all__ = ['Cait', 'ClassAttn', 'LayerScaleBlockClassAttn', 'LayerScaleBlock', 'TalkingHeadAttn']

@ -7,7 +7,6 @@ Official CoaT code at: https://github.com/mlpc-ucsd/CoaT
Modified from timm/models/vision_transformer.py
"""
-from copy import deepcopy
from functools import partial
from typing import Tuple, List, Union
@ -16,19 +15,11 @@ import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .helpers import build_model_with_cfg
-from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
-from .registry import register_model
-from .layers import _assert
+from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert
+from ._builder import build_model_with_cfg
+from ._registry import register_model
-__all__ = [
-"coat_tiny",
-"coat_mini",
-"coat_lite_tiny",
-"coat_lite_mini",
-"coat_lite_small"
-]
+__all__ = ['CoaT']
def _cfg_coat(url='', **kwargs):

@ -22,20 +22,20 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
'''
+from functools import partial
import torch
import torch.nn as nn
-from functools import partial
-import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .helpers import build_model_with_cfg
-from .layers import DropPath, to_2tuple, trunc_normal_, PatchEmbed, Mlp
-from .registry import register_model
+from timm.layers import DropPath, trunc_normal_, PatchEmbed, Mlp
+from ._builder import build_model_with_cfg
+from ._features_fx import register_notrace_module
+from ._registry import register_model
from .vision_transformer_hybrid import HybridEmbed
-from .fx_features import register_notrace_module
-import torch
-import torch.nn as nn
+__all__ = ['ConViT']
def _cfg(url='', **kwargs):

@ -5,9 +5,12 @@ import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.models.registry import register_model
-from .helpers import build_model_with_cfg, checkpoint_seq
-from .layers import SelectAdaptivePool2d
+from timm.layers import SelectAdaptivePool2d
+from ._registry import register_model
+from ._builder import build_model_with_cfg
+from ._manipulate import checkpoint_seq
+__all__ = ['ConvMixer']
def _cfg(url='', **kwargs):

@ -18,12 +18,12 @@ import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
-from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, LayerNorm, \
-create_conv2d, get_act_layer, make_divisible, to_ntuple
-from .pretrained import generate_default_cfgs
-from .registry import register_model
+from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, LayerNorm, \
+create_conv2d, get_act_layer, make_divisible, to_ntuple
+from ._builder import build_model_with_cfg
+from ._manipulate import named_apply, checkpoint_seq
+from ._pretrained import generate_default_cfgs
+from ._registry import register_model
__all__ = ['ConvNeXt']  # model_registry will add each entrypoint fn to this

@ -24,21 +24,22 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
Modifed from Timm. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
+from functools import partial
+from typing import List
from typing import Tuple
import torch
-import torch.nn as nn
-import torch.nn.functional as F
import torch.hub
-from functools import partial
-from typing import List
+import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .fx_features import register_notrace_function
-from .helpers import build_model_with_cfg
-from .layers import DropPath, to_2tuple, trunc_normal_, _assert
-from .registry import register_model
-from .vision_transformer import Mlp, Block
+from timm.layers import DropPath, to_2tuple, trunc_normal_, _assert
+from ._builder import build_model_with_cfg
+from ._features_fx import register_notrace_function
+from ._registry import register_model
+from .vision_transformer import Block
+__all__ = ['CrossViT']  # model_registry will add each entrypoint fn to this
def _cfg(url='', **kwargs):

@ -12,20 +12,18 @@ Reference impl via darknet cfg files at https://github.com/WongKinYiu/CrossStage
Hacked together by / Copyright 2020 Ross Wightman
"""
-import collections.abc
-from dataclasses import dataclass, field, asdict
+from dataclasses import dataclass, asdict
from functools import partial
-from typing import Any, Callable, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
-import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .helpers import build_model_with_cfg, named_apply, MATCH_PREV_GROUP
-from .layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, get_attn, create_act_layer, make_divisible
-from .registry import register_model
+from timm.layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, get_attn, create_act_layer, make_divisible
+from ._builder import build_model_with_cfg
+from ._manipulate import named_apply, MATCH_PREV_GROUP
+from ._registry import register_model
__all__ = ['CspNet']  # model_registry will add each entrypoint fn to this

@ -17,9 +17,11 @@ from torch import nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.vision_transformer import VisionTransformer, trunc_normal_, checkpoint_filter_fn
+from ._builder import build_model_with_cfg
+from ._manipulate import checkpoint_seq
+from ._registry import register_model
-from .helpers import build_model_with_cfg, checkpoint_seq
-from .registry import register_model
+__all__ = ['VisionTransformerDistilled']  # model_registry will add each entrypoint fn to this
def _cfg(url='', **kwargs):

@ -4,7 +4,6 @@ fixed kwargs passthrough and addition of dynamic global avg/max pool.
""" """
import re import re
from collections import OrderedDict from collections import OrderedDict
from functools import partial
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -13,9 +12,10 @@ import torch.utils.checkpoint as cp
from torch.jit.annotations import List
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .helpers import build_model_with_cfg, MATCH_PREV_GROUP
-from .layers import BatchNormAct2d, create_norm_act_layer, BlurPool2d, create_classifier
-from .registry import register_model
+from timm.layers import BatchNormAct2d, create_norm_act_layer, BlurPool2d, create_classifier
+from ._builder import build_model_with_cfg
+from ._manipulate import MATCH_PREV_GROUP
+from ._registry import register_model
__all__ = ['DenseNet']

@ -13,9 +13,9 @@ import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .helpers import build_model_with_cfg
-from .layers import create_classifier
-from .registry import register_model
+from timm.layers import create_classifier
+from ._builder import build_model_with_cfg
+from ._registry import register_model
__all__ = ['DLA']

@ -15,9 +15,9 @@ import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .helpers import build_model_with_cfg
-from .layers import BatchNormAct2d, ConvNormAct, create_conv2d, create_classifier
-from .registry import register_model
+from timm.layers import BatchNormAct2d, ConvNormAct, create_conv2d, create_classifier
+from ._builder import build_model_with_cfg
+from ._registry import register_model
__all__ = ['DPN']

@ -8,20 +8,20 @@ Original code and weights from https://github.com/mmaaz60/EdgeNeXt
Modifications and additions for timm by / Copyright 2022, Ross Wightman
"""
import math
-import torch
from collections import OrderedDict
from functools import partial
from typing import Tuple
-from torch import nn
+import torch
import torch.nn.functional as F
+from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .fx_features import register_notrace_module
-from .layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d
-from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
-from .registry import register_model
+from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d
+from ._builder import build_model_with_cfg
+from ._features_fx import register_notrace_module
+from ._manipulate import named_apply, checkpoint_seq
+from ._registry import register_model
__all__ = ['EdgeNeXt']  # model_registry will add each entrypoint fn to this

@ -18,9 +18,11 @@ import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .helpers import build_model_with_cfg
-from .layers import DropPath, trunc_normal_, to_2tuple, Mlp
-from .registry import register_model
+from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp
+from ._builder import build_model_with_cfg
+from ._registry import register_model
+__all__ = ['EfficientFormer']  # model_registry will add each entrypoint fn to this
def _cfg(url='', **kwargs):

@@ -42,15 +42,15 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from .efficientnet_blocks import SqueezeExcite
-from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights,\
+from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, GroupNormAct
+from ._builder import build_model_with_cfg, pretrained_cfg_for_features
+from ._efficientnet_blocks import SqueezeExcite
+from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
     round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
-from .features import FeatureInfo, FeatureHooks
-from .helpers import build_model_with_cfg, pretrained_cfg_for_features, checkpoint_seq
-from .layers import create_conv2d, create_classifier, get_norm_act_layer, EvoNorm2dS0, GroupNormAct
-from .registry import register_model
+from ._features import FeatureInfo, FeatureHooks
+from ._manipulate import checkpoint_seq
+from ._registry import register_model
 __all__ = ['EfficientNet', 'EfficientNetFeatures']
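
Taken together, the hunks above follow one refactoring pattern: layer utilities move from `timm.models.layers` to the top-level `timm.layers` package, and the helper modules (`helpers`, `registry`, `features`, `fx_features`, `efficientnet_blocks`, `efficientnet_builder`) become underscore-prefixed internal modules (`_builder`, `_manipulate`, `_registry`, `_features`, `_features_fx`, ...). For downstream code that imported those paths directly, a hedged sketch of the corresponding update (assuming the 0.8.x layout introduced by this change) looks like this:

```python
# Old layout (pre-0.8.x) -- these imports reached into modules that have since
# been renamed or turned into deprecation shims:
#   from timm.models.layers import DropPath, trunc_normal_
#   from timm.models.registry import register_model
#   from timm.models.helpers import load_state_dict

# New layout shown in this diff:
from timm.layers import DropPath, trunc_normal_   # layers promoted to timm.layers
from timm.models import register_model            # re-exported from timm.models._registry
from timm.models import load_state_dict           # re-exported from the split-up helpers
```

The backward-compatibility side of this pattern, keeping the old module path alive as a shim that warns, is exactly what the `factory.py` hunk below does.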

@@ -1,100 +1,4 @@
-import os
-from typing import Any, Dict, Optional, Union
-from urllib.parse import urlsplit
-
-from .pretrained import PretrainedCfg, split_model_name_tag
-from .helpers import load_checkpoint
-from .hub import load_model_config_from_hf
-from .layers import set_layer_config
-from .registry import is_model, model_entrypoint
-
-
-def parse_model_name(model_name):
-    if model_name.startswith('hf_hub'):
-        # NOTE for backwards compat, deprecate hf_hub use
-        model_name = model_name.replace('hf_hub', 'hf-hub')
-    parsed = urlsplit(model_name)
-    assert parsed.scheme in ('', 'timm', 'hf-hub')
-    if parsed.scheme == 'hf-hub':
-        # FIXME may use fragment as revision, currently `@` in URI path
-        return parsed.scheme, parsed.path
-    else:
-        model_name = os.path.split(parsed.path)[-1]
-        return 'timm', model_name
-
-
-def safe_model_name(model_name, remove_source=True):
-    # return a filename / path safe model name
-    def make_safe(name):
-        return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_')
-    if remove_source:
-        model_name = parse_model_name(model_name)[-1]
-    return make_safe(model_name)
-
-
-def create_model(
-        model_name: str,
-        pretrained: bool = False,
-        pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
-        pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
-        checkpoint_path: str = '',
-        scriptable: Optional[bool] = None,
-        exportable: Optional[bool] = None,
-        no_jit: Optional[bool] = None,
-        **kwargs,
-):
-    """Create a model
-
-    Lookup model's entrypoint function and pass relevant args to create a new model.
-
-    **kwargs will be passed through entrypoint fn to timm.models.build_model_with_cfg()
-    and then the model class __init__(). kwargs values set to None are pruned before passing.
-
-    Args:
-        model_name (str): name of model to instantiate
-        pretrained (bool): load pretrained ImageNet-1k weights if true
-        pretrained_cfg (Union[str, dict, PretrainedCfg]): pass in external pretrained_cfg for model
-        pretrained_cfg_overlay (dict): replace key-values in base pretrained_cfg with these
-        checkpoint_path (str): path of checkpoint to load _after_ the model is initialized
-        scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet)
-        exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet)
-        no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only)
-
-    Keyword Args:
-        drop_rate (float): dropout rate for training (default: 0.0)
-        global_pool (str): global pool type (default: 'avg')
-        **: other kwargs are consumed by builder or model __init__()
-    """
-    # Parameters that aren't supported by all models or are intended to only override model defaults if set
-    # should default to None in command line args/cfg. Remove them if they are present and not set so that
-    # non-supporting models don't break and default args remain in effect.
-    kwargs = {k: v for k, v in kwargs.items() if v is not None}
-
-    model_source, model_name = parse_model_name(model_name)
-    if model_source == 'hf-hub':
-        assert not pretrained_cfg, 'pretrained_cfg should not be set when sourcing model from Hugging Face Hub.'
-        # For model names specified in the form `hf-hub:path/architecture_name@revision`,
-        # load model weights + pretrained_cfg from Hugging Face hub.
-        pretrained_cfg, model_name = load_model_config_from_hf(model_name)
-    else:
-        model_name, pretrained_tag = split_model_name_tag(model_name)
-        if not pretrained_cfg:
-            # a valid pretrained_cfg argument takes priority over tag in model name
-            pretrained_cfg = pretrained_tag
-
-    if not is_model(model_name):
-        raise RuntimeError('Unknown model (%s)' % model_name)
-
-    create_fn = model_entrypoint(model_name)
-    with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
-        model = create_fn(
-            pretrained=pretrained,
-            pretrained_cfg=pretrained_cfg,
-            pretrained_cfg_overlay=pretrained_cfg_overlay,
-            **kwargs,
-        )
-
-    if checkpoint_path:
-        load_checkpoint(model, checkpoint_path)
-
-    return model
+from ._factory import *
+
+import warnings
+warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
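
The `factory.py` hunk above replaces the whole implementation with a four-line shim: the real code now lives in `_factory.py`, the old module star-imports it back, and a module-level `DeprecationWarning` fires when the legacy path is imported. A hedged usage sketch, assuming the 0.8.x pre-release layout shown here (later releases may drop the shim entirely):

```python
import warnings

import timm

# Preferred: the public entrypoints re-exported by timm / timm.models.
model = timm.create_model('resnet18', pretrained=False)

# Legacy: timm.models.factory still resolves through the shim, but importing it
# emits a DeprecationWarning (only on first import; a cached module won't re-warn).
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    from timm.models import factory  # noqa: F401
print([str(w.message) for w in caught])
```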

Some files were not shown because too many files have changed in this diff.
