diff --git a/README.md b/README.md index 994775f1..bb6485c0 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,25 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before ## What's New -# Dec 6, 2022 +### 🤗 Survey: Feedback Appreciated 🤗 + +For a few months now, `timm` has been part of the Hugging Face ecosystem. Yearly, we survey users of our tools to see what we could do better, what we need to continue doing, or what we need to stop doing. + +If you have a couple of minutes and want to participate in shaping the future of the ecosystem, please share your thoughts: +[**hf.co/oss-survey**](https://hf.co/oss-survey) 🙏 + +### Dec 8, 2022 +* Add 'EVA l' to `vision_transformer.py`, MAE style ViT-L/14 MIM pretrain w/ EVA-CLIP targets, FT on ImageNet-1k (w/ ImageNet-22k intermediate for some) + * original source: https://github.com/baaivision/EVA + +| model | top1 | param_count | gmac | macts | hub | +|:------------------------------------------|-----:|------------:|------:|------:|:----------------------------------------| +| eva_large_patch14_336.in22k_ft_in22k_in1k | 89.2 | 304.5 | 191.1 | 270.2 | [link](https://huggingface.co/BAAI/EVA) | +| eva_large_patch14_336.in22k_ft_in1k | 88.7 | 304.5 | 191.1 | 270.2 | [link](https://huggingface.co/BAAI/EVA) | +| eva_large_patch14_196.in22k_ft_in22k_in1k | 88.6 | 304.1 | 61.6 | 63.5 | [link](https://huggingface.co/BAAI/EVA) | +| eva_large_patch14_196.in22k_ft_in1k | 87.9 | 304.1 | 61.6 | 63.5 | [link](https://huggingface.co/BAAI/EVA) | + +### Dec 6, 2022 * Add 'EVA g', BEiT style ViT-g/14 model weights w/ both MIM pretrain and CLIP pretrain to `beit.py`. * original source: https://github.com/baaivision/EVA * paper: https://arxiv.org/abs/2211.07636 @@ -33,7 +51,7 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before | eva_giant_patch14_336.clip_ft_in1k | 89.4 | 1013 | 620.6 | 550.7 | [link](https://huggingface.co/BAAI/EVA) | | eva_giant_patch14_224.clip_ft_in1k | 89.1 | 1012.6 | 267.2 | 192.6 | [link](https://huggingface.co/BAAI/EVA) | -# Dec 5, 2022 +### Dec 5, 2022 * Pre-release (`0.8.0dev0`) of multi-weight support (`model_arch.pretrained_tag`). 
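As a quick illustration of the tag-based naming (a minimal sketch; it assumes the `0.8.0dev0` pre-release described below is installed and that the EVA weights listed above are reachable), a specific weight set is selected by appending its tag to the architecture name:

```python
import timm

# 'architecture.pretrained_tag' picks one weight set from the table above
model = timm.create_model('eva_large_patch14_196.in22k_ft_in1k', pretrained=True)
model = model.eval()
```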
Install with `pip install --pre timm` * vision_transformer, maxvit, convnext are the first three model impl w/ support diff --git a/avg_checkpoints.py b/avg_checkpoints.py index ea8bbe84..83af5bbd 100755 --- a/avg_checkpoints.py +++ b/avg_checkpoints.py @@ -16,7 +16,7 @@ import argparse import os import glob import hashlib -from timm.models.helpers import load_state_dict +from timm.models import load_state_dict parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager') parser.add_argument('--input', default='', type=str, metavar='PATH', diff --git a/benchmark.py b/benchmark.py index 9adeb465..58435ff8 100755 --- a/benchmark.py +++ b/benchmark.py @@ -19,7 +19,8 @@ import torch.nn as nn import torch.nn.parallel from timm.data import resolve_data_config -from timm.models import create_model, is_model, list_models, set_fast_norm +from timm.layers import set_fast_norm +from timm.models import create_model, is_model, list_models from timm.optim import create_optimizer_v2 from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry diff --git a/clean_checkpoint.py b/clean_checkpoint.py index 8ec892b2..17c270db 100755 --- a/clean_checkpoint.py +++ b/clean_checkpoint.py @@ -13,7 +13,7 @@ import os import hashlib import shutil from collections import OrderedDict -from timm.models.helpers import load_state_dict +from timm.models import load_state_dict parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner') parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', diff --git a/hubconf.py b/hubconf.py index 70fed79a..6b2061ea 100644 --- a/hubconf.py +++ b/hubconf.py @@ -1,4 +1,3 @@ dependencies = ['torch'] -from timm.models import registry - -globals().update(registry._model_entrypoints) +import timm +globals().update(timm.models._registry._model_entrypoints) diff --git a/inference.py b/inference.py index bc794840..1509b323 100755 --- a/inference.py +++ b/inference.py @@ -5,11 +5,11 @@ An example inference script that outputs top-k class ids for images in a folder Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman) """ -import os -import time import argparse import json import logging +import os +import time from contextlib import suppress from functools import partial @@ -17,12 +17,11 @@ import numpy as np import pandas as pd import torch -from timm.models import create_model, apply_test_time_pool, load_checkpoint from timm.data import create_dataset, create_loader, resolve_data_config +from timm.layers import apply_test_time_pool +from timm.models import create_model from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser - - try: from apex import amp has_apex = True diff --git a/tests/test_layers.py b/tests/test_layers.py index 508a6aae..da061870 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -1,10 +1,7 @@ -import pytest import torch import torch.nn as nn -import platform -import os -from timm.models.layers import create_act_layer, get_act_layer, set_layer_config +from timm.layers import create_act_layer, set_layer_config class MLP(nn.Module): diff --git a/tests/test_models.py b/tests/test_models.py index 008d87b7..4c848440 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -14,7 +14,7 @@ except ImportError: import timm from timm import list_models, create_model, set_scriptable, get_pretrained_cfg_value -from timm.models.fx_features import _leaf_modules, _autowrap_functions +from timm.models._features_fx import _leaf_modules, 
_autowrap_functions if hasattr(torch._C, '_jit_set_profiling_executor'): # legacy executor is too slow to compile large models for unit tests diff --git a/timm/__init__.py b/timm/__init__.py index faf34dbc..3d38cdb9 100644 --- a/timm/__init__.py +++ b/timm/__init__.py @@ -1,4 +1,4 @@ from .version import __version__ +from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable from .models import create_model, list_models, list_pretrained, is_model, list_modules, model_entrypoint, \ - is_scriptable, is_exportable, set_scriptable, set_exportable, \ is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py index 1b51ccb4..a7701b82 100644 --- a/timm/data/auto_augment.py +++ b/timm/data/auto_augment.py @@ -1,4 +1,4 @@ -""" AutoAugment, RandAugment, and AugMix for PyTorch +""" AutoAugment, RandAugment, AugMix, and 3-Augment for PyTorch This code implements the searched ImageNet policies with various tweaks and improvements and does not include any of the search code. @@ -9,18 +9,24 @@ AA and RA Implementation adapted from: AugMix adapted from: https://github.com/google-research/augmix +3-Augment based on: https://github.com/facebookresearch/deit/blob/main/README_revenge.md + Papers: AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501 Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172 RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719 AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781 + 3-Augment: DeiT III: Revenge of the ViT - https://arxiv.org/abs/2204.07118 Hacked together by / Copyright 2019, Ross Wightman """ import random import math import re -from PIL import Image, ImageOps, ImageEnhance, ImageChops +from functools import partial +from typing import Dict, List, Optional, Union + +from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageFilter import PIL import numpy as np @@ -175,6 +181,24 @@ def sharpness(img, factor, **__): return ImageEnhance.Sharpness(img).enhance(factor) +def gaussian_blur(img, factor, **__): + img = img.filter(ImageFilter.GaussianBlur(radius=factor)) + return img + + +def gaussian_blur_rand(img, factor, **__): + radius_min = 0.1 + radius_max = 2.0 + img = img.filter(ImageFilter.GaussianBlur(radius=random.uniform(radius_min, radius_max * factor))) + return img + + +def desaturate(img, factor, **_): + factor = min(1., max(0., 1. 
- factor)) + # enhance factor 0 = grayscale, 1.0 = no-change + return ImageEnhance.Color(img).enhance(factor) + + def _randomly_negate(v): """With 50% prob, negate the value""" return -v if random.random() > 0.5 else v @@ -200,6 +224,14 @@ def _enhance_increasing_level_to_arg(level, _hparams): return level, +def _minmax_level_to_arg(level, _hparams, min_val=0., max_val=1.0, clamp=True): + level = (level / _LEVEL_DENOM) + min_val + (max_val - min_val) * level + if clamp: + level = max(min_val, min(max_val, level)) + return level, + + def _shear_level_to_arg(level, _hparams): # range [-0.3, 0.3] level = (level / _LEVEL_DENOM) * 0.3 @@ -246,7 +278,7 @@ def _posterize_original_level_to_arg(level, _hparams): def _solarize_level_to_arg(level, _hparams): # range [0, 256] # intensity/severity of augmentation decreases with level - return int((level / _LEVEL_DENOM) * 256), + return min(256, int((level / _LEVEL_DENOM) * 256)), def _solarize_increasing_level_to_arg(level, _hparams): @@ -257,7 +289,7 @@ def _solarize_increasing_level_to_arg(level, _hparams): def _solarize_add_level_to_arg(level, _hparams): # range [0, 110] - return int((level / _LEVEL_DENOM) * 110), + return min(128, int((level / _LEVEL_DENOM) * 110)), LEVEL_TO_ARG = { @@ -286,6 +318,9 @@ LEVEL_TO_ARG = { 'TranslateY': _translate_abs_level_to_arg, 'TranslateXRel': _translate_rel_level_to_arg, 'TranslateYRel': _translate_rel_level_to_arg, + 'Desaturate': partial(_minmax_level_to_arg, min_val=0.5, max_val=1.0), + 'GaussianBlur': partial(_minmax_level_to_arg, min_val=0.1, max_val=2.0), + 'GaussianBlurRand': _minmax_level_to_arg, } @@ -314,6 +349,9 @@ NAME_TO_OP = { 'TranslateY': translate_y_abs, 'TranslateXRel': translate_x_rel, 'TranslateYRel': translate_y_rel, + 'Desaturate': desaturate, + 'GaussianBlur': gaussian_blur, + 'GaussianBlurRand': gaussian_blur_rand, } @@ -347,6 +385,7 @@ class AugmentOp: if self.magnitude_std > 0: # magnitude randomization enabled if self.magnitude_std == float('inf'): + # inf == uniform sampling magnitude = random.uniform(0, magnitude) elif self.magnitude_std > 0: magnitude = random.gauss(magnitude, self.magnitude_std) @@ -499,6 +538,16 @@ def auto_augment_policy_originalr(hparams): return pc +def auto_augment_policy_3a(hparams): + policy = [ + [('Solarize', 1.0, 5)], # 128 solarize threshold @ 5 magnitude + [('Desaturate', 1.0, 10)], # grayscale at 10 magnitude + [('GaussianBlurRand', 1.0, 10)], + ] + pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] + return pc + + def auto_augment_policy(name='v0', hparams=None): hparams = hparams or _HPARAMS_DEFAULT if name == 'original': @@ -509,6 +558,8 @@ def auto_augment_policy(name='v0', hparams=None): return auto_augment_policy_v0(hparams) elif name == 'v0r': return auto_augment_policy_v0r(hparams) + elif name == '3a': + return auto_augment_policy_3a(hparams) else: assert False, 'Unknown AA policy (%s)' % name @@ -534,19 +585,23 @@ class AutoAugment: return fs -def auto_augment_transform(config_str, hparams): +def auto_augment_transform(config_str: str, hparams: Optional[Dict] = None): """ Create a AutoAugment transform - :param config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by - dashes ('-'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr'). 
- The remaining sections, not order sepecific determine - 'mstd' - float std deviation of magnitude noise applied - Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5 + Args: + config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by + dashes ('-'). + The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr'). - :param hparams: Other hparams (kwargs) for the AutoAugmentation scheme + The remaining sections: + 'mstd' - float std deviation of magnitude noise applied + Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5 - :return: A PyTorch compatible Transform + hparams: Other hparams (kwargs) for the AutoAugmentation scheme + + Returns: + A PyTorch compatible Transform """ config = config_str.split('-') policy_name = config[0] @@ -605,42 +660,80 @@ _RAND_INCREASING_TRANSFORMS = [ ] +_RAND_3A = [ + 'SolarizeIncreasing', + 'Desaturate', + 'GaussianBlur', +] + + +_RAND_CHOICE_3A = { + 'SolarizeIncreasing': 6, + 'Desaturate': 6, + 'GaussianBlur': 6, + 'Rotate': 3, + 'ShearX': 2, + 'ShearY': 2, + 'PosterizeIncreasing': 1, + 'AutoContrast': 1, + 'ColorIncreasing': 1, + 'SharpnessIncreasing': 1, + 'ContrastIncreasing': 1, + 'BrightnessIncreasing': 1, + 'Equalize': 1, + 'Invert': 1, +} + # These experimental weights are based loosely on the relative improvements mentioned in paper. # They may not result in increased performance, but could likely be tuned to so. _RAND_CHOICE_WEIGHTS_0 = { - 'Rotate': 0.3, - 'ShearX': 0.2, - 'ShearY': 0.2, - 'TranslateXRel': 0.1, - 'TranslateYRel': 0.1, - 'Color': .025, - 'Sharpness': 0.025, - 'AutoContrast': 0.025, - 'Solarize': .005, - 'SolarizeAdd': .005, - 'Contrast': .005, - 'Brightness': .005, - 'Equalize': .005, - 'Posterize': 0, - 'Invert': 0, + 'Rotate': 3, + 'ShearX': 2, + 'ShearY': 2, + 'TranslateXRel': 1, + 'TranslateYRel': 1, + 'ColorIncreasing': .25, + 'SharpnessIncreasing': 0.25, + 'AutoContrast': 0.25, + 'SolarizeIncreasing': .05, + 'SolarizeAdd': .05, + 'ContrastIncreasing': .05, + 'BrightnessIncreasing': .05, + 'Equalize': .05, + 'PosterizeIncreasing': 0.05, + 'Invert': 0.05, } -def _select_rand_weights(weight_idx=0, transforms=None): - transforms = transforms or _RAND_TRANSFORMS - assert weight_idx == 0 # only one set of weights currently - rand_weights = _RAND_CHOICE_WEIGHTS_0 - probs = [rand_weights[k] for k in transforms] - probs /= np.sum(probs) - return probs +def _get_weighted_transforms(transforms: Dict): + transforms, probs = list(zip(*transforms.items())) + probs = np.array(probs) + probs = probs / np.sum(probs) + return transforms, probs + +def rand_augment_choices(name: str, increasing=True): + if name == 'weights': + return _RAND_CHOICE_WEIGHTS_0 + elif name == '3aw': + return _RAND_CHOICE_3A + elif name == '3a': + return _RAND_3A + else: + return _RAND_INCREASING_TRANSFORMS if increasing else _RAND_TRANSFORMS -def rand_augment_ops(magnitude=10, hparams=None, transforms=None): + +def rand_augment_ops( + magnitude: Union[int, float] = 10, + prob: float = 0.5, + hparams: Optional[Dict] = None, + transforms: Optional[Union[Dict, List]] = None, +): hparams = hparams or _HPARAMS_DEFAULT transforms = transforms or _RAND_TRANSFORMS return [AugmentOp( - name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms] + name, prob=prob, magnitude=magnitude, hparams=hparams) for name in transforms] class RandAugment: @@ -648,11 +741,16 @@ class RandAugment: self.ops = ops 
self.num_layers = num_layers self.choice_weights = choice_weights def __call__(self, img): # no replacement when using weighted choice ops = np.random.choice( - self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights) + self.ops, + self.num_layers, + replace=self.choice_weights is None, + p=self.choice_weights, + ) for op in ops: img = op(img) return img @@ -665,61 +763,84 @@ class RandAugment: return fs -def rand_augment_transform(config_str, hparams): +def rand_augment_transform( + config_str: str, + hparams: Optional[Dict] = None, + transforms: Optional[Union[str, Dict, List]] = None, +): """ Create a RandAugment transform - :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by - dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining - sections, not order sepecific determine - 'm' - integer magnitude of rand augment - 'n' - integer num layers (number of transform ops selected per image) - 'w' - integer probabiliy weight index (index of a set of weights to influence choice of op) - 'mstd' - float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100) - 'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10) - 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0) - Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 - 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2 - - :param hparams: Other hparams (kwargs) for the RandAugmentation scheme - - :return: A PyTorch compatible Transform + Args: + config_str (str): String defining configuration of random augmentation. Consists of multiple sections separated + by dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand').
+ The remaining sections, not order sepecific determine + 'm' - integer magnitude of rand augment + 'n' - integer num layers (number of transform ops selected per image) + 'p' - float probability of applying each layer (default 0.5) + 'mstd' - float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100) + 'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10) + 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0) + 't' - str name of transform set to use + Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 + 'rand-mstd1-tweights' results in mag std 1.0, weighted transforms, default mag of 10 and num_layers 2 + + hparams (dict): Other hparams (kwargs) for the RandAugmentation scheme + + Returns: + A PyTorch compatible Transform """ magnitude = _LEVEL_DENOM # default to _LEVEL_DENOM for magnitude (currently 10) num_layers = 2 # default to 2 ops per image - weight_idx = None # default to no probability weights for op choice - transforms = _RAND_TRANSFORMS + increasing = False + prob = 0.5 config = config_str.split('-') assert config[0] == 'rand' config = config[1:] for c in config: - cs = re.split(r'(\d.*)', c) - if len(cs) < 2: - continue - key, val = cs[:2] - if key == 'mstd': - # noise param / randomization of magnitude values - mstd = float(val) - if mstd > 100: - # use uniform sampling in 0 to magnitude if mstd is > 100 - mstd = float('inf') - hparams.setdefault('magnitude_std', mstd) - elif key == 'mmax': - # clip magnitude between [0, mmax] instead of default [0, _LEVEL_DENOM] - hparams.setdefault('magnitude_max', int(val)) - elif key == 'inc': - if bool(val): - transforms = _RAND_INCREASING_TRANSFORMS - elif key == 'm': - magnitude = int(val) - elif key == 'n': - num_layers = int(val) - elif key == 'w': - weight_idx = int(val) + if c.startswith('t'): + # NOTE old 'w' key was removed, 'w0' is not equivalent to 'tweights' + val = str(c[1:]) + if transforms is None: + transforms = val else: - assert False, 'Unknown RandAugment config section' - ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms) - choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx) + # numeric options + cs = re.split(r'(\d.*)', c) + if len(cs) < 2: + continue + key, val = cs[:2] + if key == 'mstd': + # noise param / randomization of magnitude values + mstd = float(val) + if mstd > 100: + # use uniform sampling in 0 to magnitude if mstd is > 100 + mstd = float('inf') + hparams.setdefault('magnitude_std', mstd) + elif key == 'mmax': + # clip magnitude between [0, mmax] instead of default [0, _LEVEL_DENOM] + hparams.setdefault('magnitude_max', int(val)) + elif key == 'inc': + if bool(val): + increasing = True + elif key == 'm': + magnitude = int(val) + elif key == 'n': + num_layers = int(val) + elif key == 'p': + prob = float(val) + else: + assert False, 'Unknown RandAugment config section' + + if isinstance(transforms, str): + transforms = rand_augment_choices(transforms, increasing=increasing) + elif transforms is None: + transforms = _RAND_INCREASING_TRANSFORMS if increasing else _RAND_TRANSFORMS + + choice_weights = None + if isinstance(transforms, Dict): + transforms, choice_weights = _get_weighted_transforms(transforms) + + ra_ops = rand_augment_ops(magnitude=magnitude, prob=prob, hparams=hparams, transforms=transforms) return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) @@ -740,11 
+861,19 @@ _AUGMIX_TRANSFORMS = [ ] -def augmix_ops(magnitude=10, hparams=None, transforms=None): +def augmix_ops( + magnitude: Union[int, float] = 10, + hparams: Optional[Dict] = None, + transforms: Optional[Union[str, Dict, List]] = None, +): hparams = hparams or _HPARAMS_DEFAULT transforms = transforms or _AUGMIX_TRANSFORMS return [AugmentOp( - name, prob=1.0, magnitude=magnitude, hparams=hparams) for name in transforms] + name, + prob=1.0, + magnitude=magnitude, + hparams=hparams + ) for name in transforms] class AugMixAugment: @@ -820,22 +949,24 @@ class AugMixAugment: return fs -def augment_and_mix_transform(config_str, hparams): +def augment_and_mix_transform(config_str: str, hparams: Optional[Dict] = None): """ Create AugMix PyTorch transform - :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by - dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining - sections, not order sepecific determine - 'm' - integer magnitude (severity) of augmentation mix (default: 3) - 'w' - integer width of augmentation chain (default: 3) - 'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1) - 'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0) - 'mstd' - float std deviation of magnitude noise applied (default: 0) - Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2 - - :param hparams: Other hparams (kwargs) for the Augmentation transforms - - :return: A PyTorch compatible Transform + Args: + config_str (str): String defining configuration of random augmentation. Consists of multiple sections separated + by dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). + The remaining sections, not order sepecific determine + 'm' - integer magnitude (severity) of augmentation mix (default: 3) + 'w' - integer width of augmentation chain (default: 3) + 'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1) + 'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0) + 'mstd' - float std deviation of magnitude noise applied (default: 0) + Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2 + + hparams: Other hparams (kwargs) for the Augmentation transforms + + Returns: + A PyTorch compatible Transform """ magnitude = 3 width = 3 diff --git a/timm/data/readers/class_map.py b/timm/data/readers/class_map.py index 6cf3f57e..885be6e2 100644 --- a/timm/data/readers/class_map.py +++ b/timm/data/readers/class_map.py @@ -1,6 +1,7 @@ import os import pickle + def load_class_map(map_or_filename, root=''): if isinstance(map_or_filename, dict): assert dict, 'class_map dict must be non-empty' @@ -14,7 +15,7 @@ def load_class_map(map_or_filename, root=''): with open(class_map_path) as f: class_to_idx = {v.strip(): k for k, v in enumerate(f)} elif class_map_ext == '.pkl': - with open(class_map_path,'rb') as f: + with open(class_map_path, 'rb') as f: class_to_idx = pickle.load(f) else: assert False, f'Unsupported class map file extension ({class_map_ext}).' 
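For reference, a small usage sketch of the augmentation options added above; the hparams values and config strings here are illustrative choices, not recommendations from this change:

```python
from timm.data.auto_augment import auto_augment_transform, rand_augment_transform

# fill/translate hparams are normally supplied by the transforms factory; values here are made up
hparams = dict(translate_const=100, img_mean=(124, 116, 104))

# DeiT-III style 3-Augment policy added in this change
aa_3a = auto_augment_transform('3a', hparams)

# RandAugment: magnitude 9, 3 ops per image, per-op probability 0.7 ('p'),
# with the new 't' section selecting the weighted transform set ('tweights')
ra = rand_augment_transform('rand-m9-n3-p0.7-tweights', hparams)
```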
diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py index 6c28383a..7749b206 100644 --- a/timm/data/transforms_factory.py +++ b/timm/data/transforms_factory.py @@ -59,6 +59,7 @@ def transforms_imagenet_train( re_count=1, re_num_splits=0, separate=False, + force_color_jitter=False, ): """ If separate==True, the transforms are returned as a tuple of 3 separate transforms @@ -77,8 +78,12 @@ def transforms_imagenet_train( primary_tfl += [transforms.RandomVerticalFlip(p=vflip)] secondary_tfl = [] + disable_color_jitter = False if auto_augment: assert isinstance(auto_augment, str) + # color jitter is typically disabled if AA/RA on, + # this allows override without breaking old hparm cfgs + disable_color_jitter = not (force_color_jitter or '3a' in auto_augment) if isinstance(img_size, (tuple, list)): img_size_min = min(img_size) else: @@ -96,8 +101,9 @@ def transforms_imagenet_train( secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)] else: secondary_tfl += [auto_augment_transform(auto_augment, aa_params)] - elif color_jitter is not None: - # color jitter is enabled when not using AA + + if color_jitter is not None and not disable_color_jitter: + # color jitter is enabled when not using AA or when forced if isinstance(color_jitter, (list, tuple)): # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation # or 4 if also augmenting hue diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py new file mode 100644 index 00000000..21c641b6 --- /dev/null +++ b/timm/layers/__init__.py @@ -0,0 +1,44 @@ +from .activations import * +from .adaptive_avgmax_pool import \ + adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d +from .blur_pool import BlurPool2d +from .classifier import ClassifierHead, create_classifier +from .cond_conv2d import CondConv2d, get_condconv_initializer +from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ + set_layer_config +from .conv2d_same import Conv2dSame, conv2d_same +from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct +from .create_act import create_act_layer, get_act_layer, get_act_fn +from .create_attn import get_attn, create_attn +from .create_conv2d import create_conv2d +from .create_norm import get_norm_layer, create_norm_layer +from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer +from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path +from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn +from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\ + EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a +from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm +from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d +from .gather_excite import GatherExcite +from .global_context import GlobalContext +from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple +from .inplace_abn import InplaceAbn +from .linear import Linear +from .mixed_conv2d import MixedConv2d +from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp +from .non_local_attn import NonLocalAttn, BatNonLocalAttn +from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d +from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm +from .padding import get_padding, get_same_padding, pad_same +from .patch_embed import 
PatchEmbed +from .pool2d_same import AvgPool2dSame, create_pool2d +from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite +from .selective_kernel import SelectiveKernel +from .separable_conv import SeparableConv2d, SeparableConvNormAct +from .space_to_depth import SpaceToDepthModule +from .split_attn import SplitAttn +from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model +from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame +from .test_time_pool import TestTimePoolHead, apply_test_time_pool +from .trace_utils import _assert, _float_to_int +from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_ diff --git a/timm/models/layers/activations.py b/timm/layers/activations.py similarity index 100% rename from timm/models/layers/activations.py rename to timm/layers/activations.py diff --git a/timm/models/layers/activations_jit.py b/timm/layers/activations_jit.py similarity index 100% rename from timm/models/layers/activations_jit.py rename to timm/layers/activations_jit.py diff --git a/timm/models/layers/activations_me.py b/timm/layers/activations_me.py similarity index 100% rename from timm/models/layers/activations_me.py rename to timm/layers/activations_me.py diff --git a/timm/models/layers/adaptive_avgmax_pool.py b/timm/layers/adaptive_avgmax_pool.py similarity index 100% rename from timm/models/layers/adaptive_avgmax_pool.py rename to timm/layers/adaptive_avgmax_pool.py diff --git a/timm/models/layers/attention_pool2d.py b/timm/layers/attention_pool2d.py similarity index 100% rename from timm/models/layers/attention_pool2d.py rename to timm/layers/attention_pool2d.py diff --git a/timm/models/layers/blur_pool.py b/timm/layers/blur_pool.py similarity index 100% rename from timm/models/layers/blur_pool.py rename to timm/layers/blur_pool.py diff --git a/timm/models/layers/bottleneck_attn.py b/timm/layers/bottleneck_attn.py similarity index 100% rename from timm/models/layers/bottleneck_attn.py rename to timm/layers/bottleneck_attn.py diff --git a/timm/models/layers/cbam.py b/timm/layers/cbam.py similarity index 100% rename from timm/models/layers/cbam.py rename to timm/layers/cbam.py diff --git a/timm/models/layers/classifier.py b/timm/layers/classifier.py similarity index 100% rename from timm/models/layers/classifier.py rename to timm/layers/classifier.py diff --git a/timm/models/layers/cond_conv2d.py b/timm/layers/cond_conv2d.py similarity index 100% rename from timm/models/layers/cond_conv2d.py rename to timm/layers/cond_conv2d.py diff --git a/timm/models/layers/config.py b/timm/layers/config.py similarity index 100% rename from timm/models/layers/config.py rename to timm/layers/config.py diff --git a/timm/models/layers/conv2d_same.py b/timm/layers/conv2d_same.py similarity index 100% rename from timm/models/layers/conv2d_same.py rename to timm/layers/conv2d_same.py diff --git a/timm/models/layers/conv_bn_act.py b/timm/layers/conv_bn_act.py similarity index 100% rename from timm/models/layers/conv_bn_act.py rename to timm/layers/conv_bn_act.py diff --git a/timm/models/layers/create_act.py b/timm/layers/create_act.py similarity index 100% rename from timm/models/layers/create_act.py rename to timm/layers/create_act.py diff --git a/timm/models/layers/create_attn.py b/timm/layers/create_attn.py similarity index 100% rename from timm/models/layers/create_attn.py rename to timm/layers/create_attn.py diff --git a/timm/models/layers/create_conv2d.py b/timm/layers/create_conv2d.py 
similarity index 100% rename from timm/models/layers/create_conv2d.py rename to timm/layers/create_conv2d.py diff --git a/timm/models/layers/create_norm.py b/timm/layers/create_norm.py similarity index 100% rename from timm/models/layers/create_norm.py rename to timm/layers/create_norm.py diff --git a/timm/models/layers/create_norm_act.py b/timm/layers/create_norm_act.py similarity index 100% rename from timm/models/layers/create_norm_act.py rename to timm/layers/create_norm_act.py diff --git a/timm/models/layers/drop.py b/timm/layers/drop.py similarity index 100% rename from timm/models/layers/drop.py rename to timm/layers/drop.py diff --git a/timm/models/layers/eca.py b/timm/layers/eca.py similarity index 100% rename from timm/models/layers/eca.py rename to timm/layers/eca.py diff --git a/timm/models/layers/evo_norm.py b/timm/layers/evo_norm.py similarity index 100% rename from timm/models/layers/evo_norm.py rename to timm/layers/evo_norm.py diff --git a/timm/models/layers/fast_norm.py b/timm/layers/fast_norm.py similarity index 100% rename from timm/models/layers/fast_norm.py rename to timm/layers/fast_norm.py diff --git a/timm/models/layers/filter_response_norm.py b/timm/layers/filter_response_norm.py similarity index 100% rename from timm/models/layers/filter_response_norm.py rename to timm/layers/filter_response_norm.py diff --git a/timm/models/layers/gather_excite.py b/timm/layers/gather_excite.py similarity index 100% rename from timm/models/layers/gather_excite.py rename to timm/layers/gather_excite.py diff --git a/timm/models/layers/global_context.py b/timm/layers/global_context.py similarity index 100% rename from timm/models/layers/global_context.py rename to timm/layers/global_context.py diff --git a/timm/models/layers/halo_attn.py b/timm/layers/halo_attn.py similarity index 100% rename from timm/models/layers/halo_attn.py rename to timm/layers/halo_attn.py diff --git a/timm/models/layers/helpers.py b/timm/layers/helpers.py similarity index 100% rename from timm/models/layers/helpers.py rename to timm/layers/helpers.py diff --git a/timm/models/layers/inplace_abn.py b/timm/layers/inplace_abn.py similarity index 100% rename from timm/models/layers/inplace_abn.py rename to timm/layers/inplace_abn.py diff --git a/timm/models/layers/lambda_layer.py b/timm/layers/lambda_layer.py similarity index 100% rename from timm/models/layers/lambda_layer.py rename to timm/layers/lambda_layer.py diff --git a/timm/models/layers/linear.py b/timm/layers/linear.py similarity index 100% rename from timm/models/layers/linear.py rename to timm/layers/linear.py diff --git a/timm/models/layers/median_pool.py b/timm/layers/median_pool.py similarity index 100% rename from timm/models/layers/median_pool.py rename to timm/layers/median_pool.py diff --git a/timm/models/layers/mixed_conv2d.py b/timm/layers/mixed_conv2d.py similarity index 100% rename from timm/models/layers/mixed_conv2d.py rename to timm/layers/mixed_conv2d.py diff --git a/timm/models/layers/ml_decoder.py b/timm/layers/ml_decoder.py similarity index 100% rename from timm/models/layers/ml_decoder.py rename to timm/layers/ml_decoder.py diff --git a/timm/models/layers/mlp.py b/timm/layers/mlp.py similarity index 100% rename from timm/models/layers/mlp.py rename to timm/layers/mlp.py diff --git a/timm/models/layers/non_local_attn.py b/timm/layers/non_local_attn.py similarity index 100% rename from timm/models/layers/non_local_attn.py rename to timm/layers/non_local_attn.py diff --git a/timm/models/layers/norm.py b/timm/layers/norm.py similarity 
index 100% rename from timm/models/layers/norm.py rename to timm/layers/norm.py diff --git a/timm/models/layers/norm_act.py b/timm/layers/norm_act.py similarity index 100% rename from timm/models/layers/norm_act.py rename to timm/layers/norm_act.py diff --git a/timm/models/layers/padding.py b/timm/layers/padding.py similarity index 100% rename from timm/models/layers/padding.py rename to timm/layers/padding.py diff --git a/timm/models/layers/patch_embed.py b/timm/layers/patch_embed.py similarity index 100% rename from timm/models/layers/patch_embed.py rename to timm/layers/patch_embed.py diff --git a/timm/models/layers/pool2d_same.py b/timm/layers/pool2d_same.py similarity index 100% rename from timm/models/layers/pool2d_same.py rename to timm/layers/pool2d_same.py diff --git a/timm/models/layers/pos_embed.py b/timm/layers/pos_embed.py similarity index 100% rename from timm/models/layers/pos_embed.py rename to timm/layers/pos_embed.py diff --git a/timm/models/layers/selective_kernel.py b/timm/layers/selective_kernel.py similarity index 100% rename from timm/models/layers/selective_kernel.py rename to timm/layers/selective_kernel.py diff --git a/timm/models/layers/separable_conv.py b/timm/layers/separable_conv.py similarity index 100% rename from timm/models/layers/separable_conv.py rename to timm/layers/separable_conv.py diff --git a/timm/models/layers/space_to_depth.py b/timm/layers/space_to_depth.py similarity index 100% rename from timm/models/layers/space_to_depth.py rename to timm/layers/space_to_depth.py diff --git a/timm/models/layers/split_attn.py b/timm/layers/split_attn.py similarity index 100% rename from timm/models/layers/split_attn.py rename to timm/layers/split_attn.py diff --git a/timm/models/layers/split_batchnorm.py b/timm/layers/split_batchnorm.py similarity index 100% rename from timm/models/layers/split_batchnorm.py rename to timm/layers/split_batchnorm.py diff --git a/timm/models/layers/squeeze_excite.py b/timm/layers/squeeze_excite.py similarity index 100% rename from timm/models/layers/squeeze_excite.py rename to timm/layers/squeeze_excite.py diff --git a/timm/models/layers/std_conv.py b/timm/layers/std_conv.py similarity index 100% rename from timm/models/layers/std_conv.py rename to timm/layers/std_conv.py diff --git a/timm/models/layers/test_time_pool.py b/timm/layers/test_time_pool.py similarity index 100% rename from timm/models/layers/test_time_pool.py rename to timm/layers/test_time_pool.py diff --git a/timm/models/layers/trace_utils.py b/timm/layers/trace_utils.py similarity index 100% rename from timm/models/layers/trace_utils.py rename to timm/layers/trace_utils.py diff --git a/timm/models/layers/weight_init.py b/timm/layers/weight_init.py similarity index 100% rename from timm/models/layers/weight_init.py rename to timm/layers/weight_init.py diff --git a/timm/models/__init__.py b/timm/models/__init__.py index b1f82789..ea945ccd 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -65,12 +65,18 @@ from .xception import * from .xception_aligned import * from .xcit import * -from .factory import create_model, parse_model_name, safe_model_name -from .helpers import load_checkpoint, resume_checkpoint, model_parameters -from .layers import TestTimePoolHead, apply_test_time_pool -from .layers import convert_splitbn_model, convert_sync_batchnorm -from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit -from .layers import set_fast_norm -from .pretrained import PretrainedCfg, filter_pretrained_cfg, 
generate_default_cfgs, split_model_name_tag -from .registry import register_model, model_entrypoint, list_models, list_pretrained, is_model, list_modules,\ +from ._builder import build_model_with_cfg, load_pretrained, load_custom_pretrained, resolve_pretrained_cfg, \ + set_pretrained_download_progress, set_pretrained_check_hash +from ._factory import create_model, parse_model_name, safe_model_name +from ._features import FeatureInfo, FeatureHooks, FeatureHookNet, FeatureListNet, FeatureDictNet +from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, \ + register_notrace_module, register_notrace_function +from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_checkpoint, resume_checkpoint +from ._hub import load_model_config_from_hf, load_state_dict_from_hf, push_to_hf_hub +from ._manipulate import model_parameters, named_apply, named_modules, named_modules_with_params, \ + group_modules, group_parameters, checkpoint_seq, adapt_input_conv +from ._pretrained import PretrainedCfg, DefaultCfg, \ + filter_pretrained_cfg, generate_default_cfgs, split_model_name_tag +from ._prune import adapt_model_from_string +from ._registry import register_model, model_entrypoint, list_models, list_pretrained, is_model, list_modules, \ is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value diff --git a/timm/models/_builder.py b/timm/models/_builder.py new file mode 100644 index 00000000..f634650e --- /dev/null +++ b/timm/models/_builder.py @@ -0,0 +1,399 @@ +import dataclasses +import logging +from copy import deepcopy +from typing import Optional, Dict, Callable, Any, Tuple + +from torch import nn as nn +from torch.hub import load_state_dict_from_url + +from timm.models._features import FeatureListNet, FeatureHookNet +from timm.models._features_fx import FeatureGraphNet +from timm.models._helpers import load_state_dict +from timm.models._hub import has_hf_hub, download_cached_file, load_state_dict_from_hf +from timm.models._manipulate import adapt_input_conv +from timm.models._pretrained import PretrainedCfg +from timm.models._prune import adapt_model_from_file +from timm.models._registry import get_pretrained_cfg + +_logger = logging.getLogger(__name__) + +# Global variables for rarely used pretrained checkpoint download progress and hash check. +# Use set_pretrained_download_progress / set_pretrained_check_hash functions to toggle. 
+_DOWNLOAD_PROGRESS = False +_CHECK_HASH = False + + +__all__ = ['set_pretrained_download_progress', 'set_pretrained_check_hash', 'load_custom_pretrained', 'load_pretrained', + 'pretrained_cfg_for_features', 'resolve_pretrained_cfg', 'build_model_with_cfg'] + + +def _resolve_pretrained_source(pretrained_cfg): + cfg_source = pretrained_cfg.get('source', '') + pretrained_url = pretrained_cfg.get('url', None) + pretrained_file = pretrained_cfg.get('file', None) + hf_hub_id = pretrained_cfg.get('hf_hub_id', None) + # resolve where to load pretrained weights from + load_from = '' + pretrained_loc = '' + if cfg_source == 'hf-hub' and has_hf_hub(necessary=True): + # hf-hub specified as source via model identifier + load_from = 'hf-hub' + assert hf_hub_id + pretrained_loc = hf_hub_id + else: + # default source == timm or unspecified + if pretrained_file: + load_from = 'file' + pretrained_loc = pretrained_file + elif pretrained_url: + load_from = 'url' + pretrained_loc = pretrained_url + elif hf_hub_id and has_hf_hub(necessary=True): + # hf-hub available as alternate weight source in default_cfg + load_from = 'hf-hub' + pretrained_loc = hf_hub_id + if load_from == 'hf-hub' and pretrained_cfg.get('hf_hub_filename', None): + # if a filename override is set, return tuple for location w/ (hub_id, filename) + pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename'] + return load_from, pretrained_loc + + +def set_pretrained_download_progress(enable=True): + """ Set download progress for pretrained weights on/off (globally). """ + global _DOWNLOAD_PROGRESS + _DOWNLOAD_PROGRESS = enable + + +def set_pretrained_check_hash(enable=True): + """ Set hash checking for pretrained weights on/off (globally). """ + global _CHECK_HASH + _CHECK_HASH = enable + + +def load_custom_pretrained( + model: nn.Module, + pretrained_cfg: Optional[Dict] = None, + load_fn: Optional[Callable] = None, +): + r"""Loads a custom (read non .pth) weight file + + Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls + a passed in custom load fn, or the `load_pretrained` model member fn. + + If the object is already present in `model_dir`, it's deserialized and returned. + The default value of `model_dir` is ``<hub_dir>/checkpoints`` where + `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`. + + Args: + model: The instantiated model to load weights into + pretrained_cfg (dict): Default pretrained model cfg + load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named + 'load_pretrained' on the model will be called if it exists + """ + pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None) + if not pretrained_cfg: + _logger.warning("Invalid pretrained config, cannot load weights.") + return + + load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg) + if not load_from: + _logger.warning("No pretrained weights exist for this model.
Using random initialization.") + return + if load_from == 'hf-hub': # FIXME + _logger.warning("Hugging Face hub not currently supported for custom load pretrained models.") + elif load_from == 'url': + pretrained_loc = download_cached_file( + pretrained_loc, + check_hash=_CHECK_HASH, + progress=_DOWNLOAD_PROGRESS + ) + + if load_fn is not None: + load_fn(model, pretrained_loc) + elif hasattr(model, 'load_pretrained'): + model.load_pretrained(pretrained_loc) + else: + _logger.warning("Valid function to load pretrained weights is not available, using random initialization.") + + +def load_pretrained( + model: nn.Module, + pretrained_cfg: Optional[Dict] = None, + num_classes: int = 1000, + in_chans: int = 3, + filter_fn: Optional[Callable] = None, + strict: bool = True, +): + """ Load pretrained checkpoint + + Args: + model (nn.Module) : PyTorch model module + pretrained_cfg (Optional[Dict]): configuration for pretrained weights / target dataset + num_classes (int): num_classes for target model + in_chans (int): in_chans for target model + filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args) + strict (bool): strict load of checkpoint + + """ + pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None) + if not pretrained_cfg: + _logger.warning("Invalid pretrained config, cannot load weights.") + return + + load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg) + if load_from == 'file': + _logger.info(f'Loading pretrained weights from file ({pretrained_loc})') + state_dict = load_state_dict(pretrained_loc) + elif load_from == 'url': + _logger.info(f'Loading pretrained weights from url ({pretrained_loc})') + state_dict = load_state_dict_from_url( + pretrained_loc, + map_location='cpu', + progress=_DOWNLOAD_PROGRESS, + check_hash=_CHECK_HASH, + ) + elif load_from == 'hf-hub': + _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})') + if isinstance(pretrained_loc, (list, tuple)): + state_dict = load_state_dict_from_hf(*pretrained_loc) + else: + state_dict = load_state_dict_from_hf(pretrained_loc) + else: + _logger.warning("No pretrained weights exist or were found for this model. 
Using random initialization.") + return + + if filter_fn is not None: + # for backwards compat with filter fn that take one arg, try one first, the two + try: + state_dict = filter_fn(state_dict) + except TypeError: + state_dict = filter_fn(state_dict, model) + + input_convs = pretrained_cfg.get('first_conv', None) + if input_convs is not None and in_chans != 3: + if isinstance(input_convs, str): + input_convs = (input_convs,) + for input_conv_name in input_convs: + weight_name = input_conv_name + '.weight' + try: + state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name]) + _logger.info( + f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)') + except NotImplementedError as e: + del state_dict[weight_name] + strict = False + _logger.warning( + f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.') + + classifiers = pretrained_cfg.get('classifier', None) + label_offset = pretrained_cfg.get('label_offset', 0) + if classifiers is not None: + if isinstance(classifiers, str): + classifiers = (classifiers,) + if num_classes != pretrained_cfg['num_classes']: + for classifier_name in classifiers: + # completely discard fully connected if model num_classes doesn't match pretrained weights + state_dict.pop(classifier_name + '.weight', None) + state_dict.pop(classifier_name + '.bias', None) + strict = False + elif label_offset > 0: + for classifier_name in classifiers: + # special case for pretrained weights with an extra background class in pretrained weights + classifier_weight = state_dict[classifier_name + '.weight'] + state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:] + classifier_bias = state_dict[classifier_name + '.bias'] + state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:] + + model.load_state_dict(state_dict, strict=strict) + + +def pretrained_cfg_for_features(pretrained_cfg): + pretrained_cfg = deepcopy(pretrained_cfg) + # remove default pretrained cfg fields that don't have much relevance for feature backbone + to_remove = ('num_classes', 'classifier', 'global_pool') # add default final pool size? 
+ for tr in to_remove: + pretrained_cfg.pop(tr, None) + return pretrained_cfg + + +def _filter_kwargs(kwargs, names): + if not kwargs or not names: + return + for n in names: + kwargs.pop(n, None) + + +def _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter): + """ Update the default_cfg and kwargs before passing to model + + Args: + pretrained_cfg: input pretrained cfg (updated in-place) + kwargs: keyword args passed to model build fn (updated in-place) + kwargs_filter: keyword arg keys that must be removed before model __init__ + """ + # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs) + default_kwarg_names = ('num_classes', 'global_pool', 'in_chans') + if pretrained_cfg.get('fixed_input_size', False): + # if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size + default_kwarg_names += ('img_size',) + + for n in default_kwarg_names: + # for legacy reasons, model __init__args uses img_size + in_chans as separate args while + # pretrained_cfg has one input_size=(C, H ,W) entry + if n == 'img_size': + input_size = pretrained_cfg.get('input_size', None) + if input_size is not None: + assert len(input_size) == 3 + kwargs.setdefault(n, input_size[-2:]) + elif n == 'in_chans': + input_size = pretrained_cfg.get('input_size', None) + if input_size is not None: + assert len(input_size) == 3 + kwargs.setdefault(n, input_size[0]) + else: + default_val = pretrained_cfg.get(n, None) + if default_val is not None: + kwargs.setdefault(n, pretrained_cfg[n]) + + # Filter keyword args for task specific model variants (some 'features only' models, etc.) + _filter_kwargs(kwargs, names=kwargs_filter) + + +def resolve_pretrained_cfg( + variant: str, + pretrained_cfg=None, + pretrained_cfg_overlay=None, +) -> PretrainedCfg: + model_with_tag = variant + pretrained_tag = None + if pretrained_cfg: + if isinstance(pretrained_cfg, dict): + # pretrained_cfg dict passed as arg, validate by converting to PretrainedCfg + pretrained_cfg = PretrainedCfg(**pretrained_cfg) + elif isinstance(pretrained_cfg, str): + pretrained_tag = pretrained_cfg + pretrained_cfg = None + + # fallback to looking up pretrained cfg in model registry by variant identifier + if not pretrained_cfg: + if pretrained_tag: + model_with_tag = '.'.join([variant, pretrained_tag]) + pretrained_cfg = get_pretrained_cfg(model_with_tag) + + if not pretrained_cfg: + _logger.warning( + f"No pretrained configuration specified for {model_with_tag} model. Using a default." 
+ f" Please add a config to the model pretrained_cfg registry or pass explicitly.") + pretrained_cfg = PretrainedCfg() # instance with defaults + + pretrained_cfg_overlay = pretrained_cfg_overlay or {} + if not pretrained_cfg.architecture: + pretrained_cfg_overlay.setdefault('architecture', variant) + pretrained_cfg = dataclasses.replace(pretrained_cfg, **pretrained_cfg_overlay) + + return pretrained_cfg + + +def build_model_with_cfg( + model_cls: Callable, + variant: str, + pretrained: bool, + pretrained_cfg: Optional[Dict] = None, + pretrained_cfg_overlay: Optional[Dict] = None, + model_cfg: Optional[Any] = None, + feature_cfg: Optional[Dict] = None, + pretrained_strict: bool = True, + pretrained_filter_fn: Optional[Callable] = None, + kwargs_filter: Optional[Tuple[str]] = None, + **kwargs, +): + """ Build model with specified default_cfg and optional model_cfg + + This helper fn aids in the construction of a model including: + * handling default_cfg and associated pretrained weight loading + * passing through optional model_cfg for models with config based arch spec + * features_only model adaptation + * pruning config / model adaptation + + Args: + model_cls (nn.Module): model class + variant (str): model variant name + pretrained (bool): load pretrained weights + pretrained_cfg (dict): model's pretrained weight/task config + model_cfg (Optional[Dict]): model's architecture config + feature_cfg (Optional[Dict]: feature extraction adapter config + pretrained_strict (bool): load pretrained weights strictly + pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights + kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model + **kwargs: model args passed through to model __init__ + """ + pruned = kwargs.pop('pruned', False) + features = False + feature_cfg = feature_cfg or {} + + # resolve and update model pretrained config and model kwargs + pretrained_cfg = resolve_pretrained_cfg( + variant, + pretrained_cfg=pretrained_cfg, + pretrained_cfg_overlay=pretrained_cfg_overlay + ) + + # FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model + pretrained_cfg = pretrained_cfg.to_dict() + + _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter) + + # Setup for feature extraction wrapper done at end of this fn + if kwargs.pop('features_only', False): + features = True + feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4)) + if 'out_indices' in kwargs: + feature_cfg['out_indices'] = kwargs.pop('out_indices') + + # Instantiate the model + if model_cfg is None: + model = model_cls(**kwargs) + else: + model = model_cls(cfg=model_cfg, **kwargs) + model.pretrained_cfg = pretrained_cfg + model.default_cfg = model.pretrained_cfg # alias for backwards compat + + if pruned: + model = adapt_model_from_file(model, variant) + + # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats + num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000)) + if pretrained: + if pretrained_cfg.get('custom_load', False): + load_custom_pretrained( + model, + pretrained_cfg=pretrained_cfg, + ) + else: + load_pretrained( + model, + pretrained_cfg=pretrained_cfg, + num_classes=num_classes_pretrained, + in_chans=kwargs.get('in_chans', 3), + filter_fn=pretrained_filter_fn, + strict=pretrained_strict, + ) + + # Wrap the model in a feature extraction module if enabled + if features: + feature_cls = FeatureListNet + if 'feature_cls' in feature_cfg: + 
feature_cls = feature_cfg.pop('feature_cls') + if isinstance(feature_cls, str): + feature_cls = feature_cls.lower() + if 'hook' in feature_cls: + feature_cls = FeatureHookNet + elif feature_cls == 'fx': + feature_cls = FeatureGraphNet + else: + assert False, f'Unknown feature class {feature_cls}' + model = feature_cls(model, **feature_cfg) + model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg) # add back default_cfg + model.default_cfg = model.pretrained_cfg # alias for backwards compat + + return model diff --git a/timm/models/efficientnet_blocks.py b/timm/models/_efficientnet_blocks.py similarity index 99% rename from timm/models/efficientnet_blocks.py rename to timm/models/_efficientnet_blocks.py index 34a31757..92b849e4 100644 --- a/timm/models/efficientnet_blocks.py +++ b/timm/models/_efficientnet_blocks.py @@ -2,13 +2,12 @@ Hacked together by / Copyright 2019, Ross Wightman """ -import math import torch import torch.nn as nn from torch.nn import functional as F -from .layers import create_conv2d, DropPath, make_divisible, create_act_layer, get_norm_act_layer +from timm.layers import create_conv2d, DropPath, make_divisible, create_act_layer, get_norm_act_layer __all__ = [ 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual'] diff --git a/timm/models/efficientnet_builder.py b/timm/models/_efficientnet_builder.py similarity index 99% rename from timm/models/efficientnet_builder.py rename to timm/models/_efficientnet_builder.py index 67d15a86..e6cd05ae 100644 --- a/timm/models/efficientnet_builder.py +++ b/timm/models/_efficientnet_builder.py @@ -14,8 +14,8 @@ from functools import partial import torch.nn as nn -from .efficientnet_blocks import * -from .layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible +from ._efficientnet_blocks import * +from timm.layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible __all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights", 'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT'] diff --git a/timm/models/_factory.py b/timm/models/_factory.py new file mode 100644 index 00000000..a8092419 --- /dev/null +++ b/timm/models/_factory.py @@ -0,0 +1,103 @@ +import os +from typing import Any, Dict, Optional, Union +from urllib.parse import urlsplit + +from timm.layers import set_layer_config +from ._pretrained import PretrainedCfg, split_model_name_tag +from ._helpers import load_checkpoint +from ._hub import load_model_config_from_hf +from ._registry import is_model, model_entrypoint + + +__all__ = ['parse_model_name', 'safe_model_name', 'create_model'] + + +def parse_model_name(model_name): + if model_name.startswith('hf_hub'): + # NOTE for backwards compat, deprecate hf_hub use + model_name = model_name.replace('hf_hub', 'hf-hub') + parsed = urlsplit(model_name) + assert parsed.scheme in ('', 'timm', 'hf-hub') + if parsed.scheme == 'hf-hub': + # FIXME may use fragment as revision, currently `@` in URI path + return parsed.scheme, parsed.path + else: + model_name = os.path.split(parsed.path)[-1] + return 'timm', model_name + + +def safe_model_name(model_name, remove_source=True): + # return a filename / path safe model name + def make_safe(name): + return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_') + if remove_source: + model_name = parse_model_name(model_name)[-1] + return make_safe(model_name) + + +def create_model( + 
model_name: str, + pretrained: bool = False, + pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None, + pretrained_cfg_overlay: Optional[Dict[str, Any]] = None, + checkpoint_path: str = '', + scriptable: Optional[bool] = None, + exportable: Optional[bool] = None, + no_jit: Optional[bool] = None, + **kwargs, +): + """Create a model + + Lookup model's entrypoint function and pass relevant args to create a new model. + + **kwargs will be passed through entrypoint fn to timm.models.build_model_with_cfg() + and then the model class __init__(). kwargs values set to None are pruned before passing. + + Args: + model_name (str): name of model to instantiate + pretrained (bool): load pretrained ImageNet-1k weights if true + pretrained_cfg (Union[str, dict, PretrainedCfg]): pass in external pretrained_cfg for model + pretrained_cfg_overlay (dict): replace key-values in base pretrained_cfg with these + checkpoint_path (str): path of checkpoint to load _after_ the model is initialized + scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet) + exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet) + no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only) + + Keyword Args: + drop_rate (float): dropout rate for training (default: 0.0) + global_pool (str): global pool type (default: 'avg') + **: other kwargs are consumed by builder or model __init__() + """ + # Parameters that aren't supported by all models or are intended to only override model defaults if set + # should default to None in command line args/cfg. Remove them if they are present and not set so that + # non-supporting models don't break and default args remain in effect. + kwargs = {k: v for k, v in kwargs.items() if v is not None} + + model_source, model_name = parse_model_name(model_name) + if model_source == 'hf-hub': + assert not pretrained_cfg, 'pretrained_cfg should not be set when sourcing model from Hugging Face Hub.' + # For model names specified in the form `hf-hub:path/architecture_name@revision`, + # load model weights + pretrained_cfg from Hugging Face hub. + pretrained_cfg, model_name = load_model_config_from_hf(model_name) + else: + model_name, pretrained_tag = split_model_name_tag(model_name) + if not pretrained_cfg: + # a valid pretrained_cfg argument takes priority over tag in model name + pretrained_cfg = pretrained_tag + + if not is_model(model_name): + raise RuntimeError('Unknown model (%s)' % model_name) + + create_fn = model_entrypoint(model_name) + with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit): + model = create_fn( + pretrained=pretrained, + pretrained_cfg=pretrained_cfg, + pretrained_cfg_overlay=pretrained_cfg_overlay, + **kwargs, + ) + + if checkpoint_path: + load_checkpoint(model, checkpoint_path) + + return model diff --git a/timm/models/_features.py b/timm/models/_features.py new file mode 100644 index 00000000..59b080cd --- /dev/null +++ b/timm/models/_features.py @@ -0,0 +1,287 @@ +""" PyTorch Feature Extraction Helpers + +A collection of classes, functions, modules to help extract features from models +and provide a common interface for describing them. 
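A quick usage sketch of the factory above; the Hugging Face Hub repo id is a placeholder, not a real repository, and exact tag names depend on the installed version:

```python
import timm

# plain architecture name, default pretrained weights
model = timm.create_model('resnet50', pretrained=True)

# multi-weight support: select a specific pretrained tag via `arch.tag`
model = timm.create_model('vit_base_patch16_224.augreg_in21k_ft_in1k', pretrained=True)

# config + weights sourced from the Hugging Face Hub (placeholder repo id)
model = timm.create_model('hf-hub:your-org/your-model', pretrained=True)

# kwargs set to None are pruned before reaching the model, so optional CLI/config
# overrides can always be passed without breaking models that don't support them
model = timm.create_model('resnet50', num_classes=10, drop_rate=None)
```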
+ +The return_layers, module re-writing idea inspired by torchvision IntermediateLayerGetter +https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py + +Hacked together by / Copyright 2020 Ross Wightman +""" +from collections import OrderedDict, defaultdict +from copy import deepcopy +from functools import partial +from typing import Dict, List, Tuple + +import torch +import torch.nn as nn + + +__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet'] + + +class FeatureInfo: + + def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]): + prev_reduction = 1 + for fi in feature_info: + # sanity check the mandatory fields, there may be additional fields depending on the model + assert 'num_chs' in fi and fi['num_chs'] > 0 + assert 'reduction' in fi and fi['reduction'] >= prev_reduction + prev_reduction = fi['reduction'] + assert 'module' in fi + self.out_indices = out_indices + self.info = feature_info + + def from_other(self, out_indices: Tuple[int]): + return FeatureInfo(deepcopy(self.info), out_indices) + + def get(self, key, idx=None): + """ Get value by key at specified index (indices) + if idx == None, returns value for key at each output index + if idx is an integer, return value for that feature module index (ignoring output indices) + if idx is a list/tupple, return value for each module index (ignoring output indices) + """ + if idx is None: + return [self.info[i][key] for i in self.out_indices] + if isinstance(idx, (tuple, list)): + return [self.info[i][key] for i in idx] + else: + return self.info[idx][key] + + def get_dicts(self, keys=None, idx=None): + """ return info dicts for specified keys (or all if None) at specified indices (or out_indices if None) + """ + if idx is None: + if keys is None: + return [self.info[i] for i in self.out_indices] + else: + return [{k: self.info[i][k] for k in keys} for i in self.out_indices] + if isinstance(idx, (tuple, list)): + return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx] + else: + return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys} + + def channels(self, idx=None): + """ feature channels accessor + """ + return self.get('num_chs', idx) + + def reduction(self, idx=None): + """ feature reduction (output stride) accessor + """ + return self.get('reduction', idx) + + def module_name(self, idx=None): + """ feature module name accessor + """ + return self.get('module', idx) + + def __getitem__(self, item): + return self.info[item] + + def __len__(self): + return len(self.info) + + +class FeatureHooks: + """ Feature Hook Helper + + This module helps with the setup and extraction of hooks for extracting features from + internal nodes in a model by node name. This works quite well in eager Python but needs + redesign for torchscript. 
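`FeatureInfo` is what `features_only` models expose as `model.feature_info`; a short sketch of the accessors (the channel/stride values in the comments are indicative, for resnet50):

```python
import timm

model = timm.create_model('resnet50', features_only=True, pretrained=False)

print(model.feature_info.channels())     # e.g. [64, 256, 512, 1024, 2048]
print(model.feature_info.reduction())    # e.g. [2, 4, 8, 16, 32]
print(model.feature_info.module_name())  # module names the features are tapped from
print(model.feature_info.get_dicts(keys=('num_chs', 'reduction')))
```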
+ """ + + def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'): + # setup feature hooks + modules = {k: v for k, v in named_modules} + for i, h in enumerate(hooks): + hook_name = h['module'] + m = modules[hook_name] + hook_id = out_map[i] if out_map else hook_name + hook_fn = partial(self._collect_output_hook, hook_id) + hook_type = h.get('hook_type', default_hook_type) + if hook_type == 'forward_pre': + m.register_forward_pre_hook(hook_fn) + elif hook_type == 'forward': + m.register_forward_hook(hook_fn) + else: + assert False, "Unsupported hook type" + self._feature_outputs = defaultdict(OrderedDict) + + def _collect_output_hook(self, hook_id, *args): + x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre + if isinstance(x, tuple): + x = x[0] # unwrap input tuple + self._feature_outputs[x.device][hook_id] = x + + def get_output(self, device) -> Dict[str, torch.tensor]: + output = self._feature_outputs[device] + self._feature_outputs[device] = OrderedDict() # clear after reading + return output + + +def _module_list(module, flatten_sequential=False): + # a yield/iter would be better for this but wouldn't be compatible with torchscript + ml = [] + for name, module in module.named_children(): + if flatten_sequential and isinstance(module, nn.Sequential): + # first level of Sequential containers is flattened into containing model + for child_name, child_module in module.named_children(): + combined = [name, child_name] + ml.append(('_'.join(combined), '.'.join(combined), child_module)) + else: + ml.append((name, name, module)) + return ml + + +def _get_feature_info(net, out_indices): + feature_info = getattr(net, 'feature_info') + if isinstance(feature_info, FeatureInfo): + return feature_info.from_other(out_indices) + elif isinstance(feature_info, (list, tuple)): + return FeatureInfo(net.feature_info, out_indices) + else: + assert False, "Provided feature_info is not valid" + + +def _get_return_layers(feature_info, out_map): + module_names = feature_info.module_name() + return_layers = {} + for i, name in enumerate(module_names): + return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i] + return return_layers + + +class FeatureDictNet(nn.ModuleDict): + """ Feature extractor with OrderedDict return + + Wrap a model and extract features as specified by the out indices, the network is + partially re-built from contained modules. + + There is a strong assumption that the modules have been registered into the model in the same + order as they are used. There should be no reuse of the same nn.Module more than once, including + trivial modules like `self.relu = nn.ReLU`. + + Only submodules that are directly assigned to the model class (`model.feature1`) or at most + one Sequential container deep (`model.features.1`, with flatten_sequent=True) can be captured. 
+ All Sequential containers that are directly assigned to the original model will have their + modules assigned to this module with the name `model.features.1` being changed to `model.features_1` + + Arguments: + model (nn.Module): model from which we will extract the features + out_indices (tuple[int]): model output indices to extract features for + out_map (sequence): list or tuple specifying desired return id for each out index, + otherwise str(index) is used + feature_concat (bool): whether to concatenate intermediate features that are lists or tuples + vs select element [0] + flatten_sequential (bool): whether to flatten sequential modules assigned to model + """ + def __init__( + self, model, + out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False): + super(FeatureDictNet, self).__init__() + self.feature_info = _get_feature_info(model, out_indices) + self.concat = feature_concat + self.return_layers = {} + return_layers = _get_return_layers(self.feature_info, out_map) + modules = _module_list(model, flatten_sequential=flatten_sequential) + remaining = set(return_layers.keys()) + layers = OrderedDict() + for new_name, old_name, module in modules: + layers[new_name] = module + if old_name in remaining: + # return id has to be consistently str type for torchscript + self.return_layers[new_name] = str(return_layers[old_name]) + remaining.remove(old_name) + if not remaining: + break + assert not remaining and len(self.return_layers) == len(return_layers), \ + f'Return layers ({remaining}) are not present in model' + self.update(layers) + + def _collect(self, x) -> (Dict[str, torch.Tensor]): + out = OrderedDict() + for name, module in self.items(): + x = module(x) + if name in self.return_layers: + out_id = self.return_layers[name] + if isinstance(x, (tuple, list)): + # If model tap is a tuple or list, concat or select first element + # FIXME this may need to be more generic / flexible for some nets + out[out_id] = torch.cat(x, 1) if self.concat else x[0] + else: + out[out_id] = x + return out + + def forward(self, x) -> Dict[str, torch.Tensor]: + return self._collect(x) + + +class FeatureListNet(FeatureDictNet): + """ Feature extractor with list return + + See docstring for FeatureDictNet above, this class exists only to appease Torchscript typing constraints. + In eager Python we could have returned List[Tensor] vs Dict[id, Tensor] based on a member bool. + """ + def __init__( + self, model, + out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False): + super(FeatureListNet, self).__init__( + model, out_indices=out_indices, out_map=out_map, feature_concat=feature_concat, + flatten_sequential=flatten_sequential) + + def forward(self, x) -> (List[torch.Tensor]): + return list(self._collect(x).values()) + + +class FeatureHookNet(nn.ModuleDict): + """ FeatureHookNet + + Wrap a model and extract features specified by the out indices using forward/forward-pre hooks. + + If `no_rewrite` is True, features are extracted via hooks without modifying the underlying + network in any way. + + If `no_rewrite` is False, the model will be re-written as in the + FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one. 
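Continuing the sketch above: selecting a subset of taps with `out_indices` and running a forward pass through the resulting `FeatureListNet` (shapes in the comment are indicative for resnet50 at 224x224):

```python
import torch
import timm

model = timm.create_model('resnet50', features_only=True, out_indices=(2, 3, 4), pretrained=False)
feats = model(torch.randn(1, 3, 224, 224))  # list of tensors, one per selected index
for x, stride in zip(feats, model.feature_info.reduction()):
    print(tuple(x.shape), 'stride', stride)  # e.g. (1, 512, 28, 28) stride 8, ...
```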
+ + FIXME this does not currently work with Torchscript, see FeatureHooks class + """ + def __init__( + self, model, + out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, no_rewrite=False, + feature_concat=False, flatten_sequential=False, default_hook_type='forward'): + super(FeatureHookNet, self).__init__() + assert not torch.jit.is_scripting() + self.feature_info = _get_feature_info(model, out_indices) + self.out_as_dict = out_as_dict + layers = OrderedDict() + hooks = [] + if no_rewrite: + assert not flatten_sequential + if hasattr(model, 'reset_classifier'): # make sure classifier is removed? + model.reset_classifier(0) + layers['body'] = model + hooks.extend(self.feature_info.get_dicts()) + else: + modules = _module_list(model, flatten_sequential=flatten_sequential) + remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type + for f in self.feature_info.get_dicts()} + for new_name, old_name, module in modules: + layers[new_name] = module + for fn, fm in module.named_modules(prefix=old_name): + if fn in remaining: + hooks.append(dict(module=fn, hook_type=remaining[fn])) + del remaining[fn] + if not remaining: + break + assert not remaining, f'Return layers ({remaining}) are not present in model' + self.update(layers) + self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map) + + def forward(self, x): + for name, module in self.items(): + x = module(x) + out = self.hooks.get_output(x.device) + return out if self.out_as_dict else list(out.values()) diff --git a/timm/models/_features_fx.py b/timm/models/_features_fx.py new file mode 100644 index 00000000..10670a1d --- /dev/null +++ b/timm/models/_features_fx.py @@ -0,0 +1,110 @@ +""" PyTorch FX Based Feature Extraction Helpers +Using https://pytorch.org/vision/stable/feature_extraction.html +""" +from typing import Callable, List, Dict, Union, Type + +import torch +from torch import nn + +from ._features import _get_feature_info + +try: + from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor + has_fx_feature_extraction = True +except ImportError: + has_fx_feature_extraction = False + +# Layers we went to treat as leaf modules +from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame +from timm.layers.non_local_attn import BilinearAttnTransform +from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame + +# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here +# BUT modules from timm.models should use the registration mechanism below +_leaf_modules = { + BilinearAttnTransform, # reason: flow control t <= 1 + # Reason: get_same_padding has a max which raises a control flow error + Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame, + CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0]) +} + +try: + from timm.layers import InplaceAbn + _leaf_modules.add(InplaceAbn) +except ImportError: + pass + + +__all__ = ['register_notrace_module', 'register_notrace_function', 'create_feature_extractor', + 'FeatureGraphNet', 'GraphExtractNet'] + + +def register_notrace_module(module: Type[nn.Module]): + """ + Any module not under timm.models.layers should get this decorator if we don't want to trace through it. 
+ """ + _leaf_modules.add(module) + return module + + +# Functions we want to autowrap (treat them as leaves) +_autowrap_functions = set() + + +def register_notrace_function(func: Callable): + """ + Decorator for functions which ought not to be traced through + """ + _autowrap_functions.add(func) + return func + + +def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]): + assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction' + return _create_feature_extractor( + model, return_nodes, + tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)} + ) + + +class FeatureGraphNet(nn.Module): + """ A FX Graph based feature extractor that works with the model feature_info metadata + """ + def __init__(self, model, out_indices, out_map=None): + super().__init__() + assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction' + self.feature_info = _get_feature_info(model, out_indices) + if out_map is not None: + assert len(out_map) == len(out_indices) + return_nodes = { + info['module']: out_map[i] if out_map is not None else info['module'] + for i, info in enumerate(self.feature_info) if i in out_indices} + self.graph_module = create_feature_extractor(model, return_nodes) + + def forward(self, x): + return list(self.graph_module(x).values()) + + +class GraphExtractNet(nn.Module): + """ A standalone feature extraction wrapper that maps dict -> list or single tensor + NOTE: + * one can use feature_extractor directly if dictionary output is desired + * unlike FeatureGraphNet, this is intended to be used standalone and not with model feature_info + metadata for builtin feature extraction mode + * create_feature_extractor can be used directly if dictionary output is acceptable + + Args: + model: model to extract features from + return_nodes: node names to return features from (dict or list) + squeeze_out: if only one output, and output in list format, flatten to single tensor + """ + def __init__(self, model, return_nodes: Union[Dict[str, str], List[str]], squeeze_out: bool = True): + super().__init__() + self.squeeze_out = squeeze_out + self.graph_module = create_feature_extractor(model, return_nodes) + + def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]: + out = list(self.graph_module(x).values()) + if self.squeeze_out and len(out) == 1: + return out[0] + return out diff --git a/timm/models/_helpers.py b/timm/models/_helpers.py new file mode 100644 index 00000000..995292aa --- /dev/null +++ b/timm/models/_helpers.py @@ -0,0 +1,115 @@ +""" Model creation / weight loading / state_dict helpers + +Hacked together by / Copyright 2020 Ross Wightman +""" +import logging +import os +from collections import OrderedDict + +import torch + +import timm.models._builder + +_logger = logging.getLogger(__name__) + +__all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_checkpoint', 'resume_checkpoint'] + + +def clean_state_dict(state_dict): + # 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training + cleaned_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k[7:] if k.startswith('module.') else k + cleaned_state_dict[name] = v + return cleaned_state_dict + + +def load_state_dict(checkpoint_path, use_ema=True): + if checkpoint_path and os.path.isfile(checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location='cpu') + 
state_dict_key = '' + if isinstance(checkpoint, dict): + if use_ema and checkpoint.get('state_dict_ema', None) is not None: + state_dict_key = 'state_dict_ema' + elif use_ema and checkpoint.get('model_ema', None) is not None: + state_dict_key = 'model_ema' + elif 'state_dict' in checkpoint: + state_dict_key = 'state_dict' + elif 'model' in checkpoint: + state_dict_key = 'model' + state_dict = clean_state_dict(checkpoint[state_dict_key] if state_dict_key else checkpoint) + _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path)) + return state_dict + else: + _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) + raise FileNotFoundError() + + +def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True, remap=False): + if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'): + # numpy checkpoint, try to load via model specific load_pretrained fn + if hasattr(model, 'load_pretrained'): + timm.models._model_builder.load_pretrained(checkpoint_path) + else: + raise NotImplementedError('Model cannot load numpy checkpoint') + return + state_dict = load_state_dict(checkpoint_path, use_ema) + if remap: + state_dict = remap_checkpoint(model, state_dict) + incompatible_keys = model.load_state_dict(state_dict, strict=strict) + return incompatible_keys + + +def remap_checkpoint(model, state_dict, allow_reshape=True): + """ remap checkpoint by iterating over state dicts in order (ignoring original keys). + This assumes models (and originating state dict) were created with params registered in same order. + """ + out_dict = {} + for (ka, va), (kb, vb) in zip(model.state_dict().items(), state_dict.items()): + assert va.numel == vb.numel, f'Tensor size mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.' + if va.shape != vb.shape: + if allow_reshape: + vb = vb.reshape(va.shape) + else: + assert False, f'Tensor shape mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.' 
+ out_dict[ka] = vb + return out_dict + + +def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True): + resume_epoch = None + if os.path.isfile(checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location='cpu') + if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + if log_info: + _logger.info('Restoring model state from checkpoint...') + state_dict = clean_state_dict(checkpoint['state_dict']) + model.load_state_dict(state_dict) + + if optimizer is not None and 'optimizer' in checkpoint: + if log_info: + _logger.info('Restoring optimizer state from checkpoint...') + optimizer.load_state_dict(checkpoint['optimizer']) + + if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint: + if log_info: + _logger.info('Restoring AMP loss scaler state from checkpoint...') + loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key]) + + if 'epoch' in checkpoint: + resume_epoch = checkpoint['epoch'] + if 'version' in checkpoint and checkpoint['version'] > 1: + resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save + + if log_info: + _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) + else: + model.load_state_dict(checkpoint) + if log_info: + _logger.info("Loaded checkpoint '{}'".format(checkpoint_path)) + return resume_epoch + else: + _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) + raise FileNotFoundError() + + diff --git a/timm/models/_hub.py b/timm/models/_hub.py new file mode 100644 index 00000000..e6b7d558 --- /dev/null +++ b/timm/models/_hub.py @@ -0,0 +1,220 @@ +import json +import logging +import os +from functools import partial +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Optional, Union + +import torch +from torch.hub import HASH_REGEX, download_url_to_file, urlparse + +try: + from torch.hub import get_dir +except ImportError: + from torch.hub import _get_torch_home as get_dir + +from timm import __version__ +from timm.models._pretrained import filter_pretrained_cfg + +try: + from huggingface_hub import ( + create_repo, get_hf_file_metadata, + hf_hub_download, hf_hub_url, + repo_type_and_id_from_hf_id, upload_folder) + from huggingface_hub.utils import EntryNotFoundError + hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__) + _has_hf_hub = True +except ImportError: + hf_hub_download = None + _has_hf_hub = False + +_logger = logging.getLogger(__name__) + +__all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf', + 'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub'] + + +def get_cache_dir(child_dir=''): + """ + Returns the location of the directory where models are cached (and creates it if necessary). 
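Typical use of the two checkpoint helpers above; the checkpoint path is a placeholder for a checkpoint written by the timm training script:

```python
import torch
import timm
from timm.models._helpers import load_checkpoint, resume_checkpoint

model = timm.create_model('resnet50', pretrained=False)

# load weights only, preferring EMA weights when the checkpoint contains them
load_checkpoint(model, './output/checkpoint.pth.tar', use_ema=True, strict=True)

# resume a full training run: also restores optimizer (and AMP scaler) state,
# returning the epoch to continue from
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
start_epoch = resume_checkpoint(model, './output/checkpoint.pth.tar', optimizer=optimizer)
```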
+ """ + # Issue warning to move data if old env is set + if os.getenv('TORCH_MODEL_ZOO'): + _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead') + + hub_dir = get_dir() + child_dir = () if not child_dir else (child_dir,) + model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir) + os.makedirs(model_dir, exist_ok=True) + return model_dir + + +def download_cached_file(url, check_hash=True, progress=False): + if isinstance(url, (list, tuple)): + url, filename = url + else: + parts = urlparse(url) + filename = os.path.basename(parts.path) + cached_file = os.path.join(get_cache_dir(), filename) + if not os.path.exists(cached_file): + _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file)) + hash_prefix = None + if check_hash: + r = HASH_REGEX.search(filename) # r is Optional[Match[str]] + hash_prefix = r.group(1) if r else None + download_url_to_file(url, cached_file, hash_prefix, progress=progress) + return cached_file + + +def has_hf_hub(necessary=False): + if not _has_hf_hub and necessary: + # if no HF Hub module installed, and it is necessary to continue, raise error + raise RuntimeError( + 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') + return _has_hf_hub + + +def hf_split(hf_id): + # FIXME I may change @ -> # and be parsed as fragment in a URI model name scheme + rev_split = hf_id.split('@') + assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.' + hf_model_id = rev_split[0] + hf_revision = rev_split[-1] if len(rev_split) > 1 else None + return hf_model_id, hf_revision + + +def load_cfg_from_json(json_file: Union[str, os.PathLike]): + with open(json_file, "r", encoding="utf-8") as reader: + text = reader.read() + return json.loads(text) + + +def _download_from_hf(model_id: str, filename: str): + hf_model_id, hf_revision = hf_split(model_id) + return hf_hub_download(hf_model_id, filename, revision=hf_revision) + + +def load_model_config_from_hf(model_id: str): + assert has_hf_hub(True) + cached_file = _download_from_hf(model_id, 'config.json') + + hf_config = load_cfg_from_json(cached_file) + if 'pretrained_cfg' not in hf_config: + # old form, pull pretrain_cfg out of the base dict + pretrained_cfg = hf_config + hf_config = {} + hf_config['architecture'] = pretrained_cfg.pop('architecture') + hf_config['num_features'] = pretrained_cfg.pop('num_features', None) + if 'labels' in pretrained_cfg: + hf_config['label_name'] = pretrained_cfg.pop('labels') + hf_config['pretrained_cfg'] = pretrained_cfg + + # NOTE currently discarding parent config as only arch name and pretrained_cfg used in timm right now + pretrained_cfg = hf_config['pretrained_cfg'] + pretrained_cfg['hf_hub_id'] = model_id # insert hf_hub id for pretrained weight load during model creation + pretrained_cfg['source'] = 'hf-hub' + if 'num_classes' in hf_config: + # model should be created with parent num_classes if they exist + pretrained_cfg['num_classes'] = hf_config['num_classes'] + model_name = hf_config['architecture'] + + return pretrained_cfg, model_name + + +def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'): + assert has_hf_hub(True) + cached_file = _download_from_hf(model_id, filename) + state_dict = torch.load(cached_file, map_location='cpu') + return state_dict + + +def save_for_hf(model, save_directory, model_config=None): + assert has_hf_hub(True) + model_config = model_config or {} + save_directory = Path(save_directory) + 
save_directory.mkdir(exist_ok=True, parents=True) + + weights_path = save_directory / 'pytorch_model.bin' + torch.save(model.state_dict(), weights_path) + + config_path = save_directory / 'config.json' + hf_config = {} + pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True) + # set some values at root config level + hf_config['architecture'] = pretrained_cfg.pop('architecture') + hf_config['num_classes'] = model_config.get('num_classes', model.num_classes) + hf_config['num_features'] = model_config.get('num_features', model.num_features) + hf_config['global_pool'] = model_config.get('global_pool', getattr(model, 'global_pool', None)) + + if 'label' in model_config: + _logger.warning( + "'label' as a config field for timm models is deprecated. Please use 'label_name' and 'display_name'. " + "Using provided 'label' field as 'label_name'.") + model_config['label_name'] = model_config.pop('label') + + label_name = model_config.pop('label_name', None) + if label_name: + assert isinstance(label_name, (dict, list, tuple)) + # map label id (classifier index) -> unique label name (ie synset for ImageNet, MID for OpenImages) + # can be a dict id: name if there are id gaps, or tuple/list if no gaps. + hf_config['label_name'] = model_config['label_name'] + + display_name = model_config.pop('display_name', None) + if display_name: + assert isinstance(display_name, dict) + # map label_name -> user interface display name + hf_config['display_name'] = model_config['display_name'] + + hf_config['pretrained_cfg'] = pretrained_cfg + hf_config.update(model_config) + + with config_path.open('w') as f: + json.dump(hf_config, f, indent=2) + + +def push_to_hf_hub( + model, + repo_id: str, + commit_message: str = 'Add model', + token: Optional[str] = None, + revision: Optional[str] = None, + private: bool = False, + create_pr: bool = False, + model_config: Optional[dict] = None, +): + # Create repo if it doesn't exist yet + repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True) + + # Infer complete repo_id from repo_url + # Can be different from the input `repo_id` if repo_owner was implicit + _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url) + repo_id = f"{repo_owner}/{repo_name}" + + # Check if README file already exist in repo + try: + get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision)) + has_readme = True + except EntryNotFoundError: + has_readme = False + + # Dump model and push to Hub + with TemporaryDirectory() as tmpdir: + # Save model weights and config. 
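A minimal push sketch for `push_to_hf_hub` (its body continues just below), assuming `huggingface_hub` is installed and a write token is configured; the repo id is a placeholder:

```python
import timm
from timm.models._hub import push_to_hf_hub

model = timm.create_model('resnet50', pretrained=True)
# creates (or reuses) the repo, writes pytorch_model.bin + config.json via save_for_hf,
# adds a stub README if one is missing, then uploads the folder as a single commit
push_to_hf_hub(model, repo_id='your-org/resnet50-demo', commit_message='Add model')
```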
+ save_for_hf(model, tmpdir, model_config=model_config) + + # Add readme if it does not exist + if not has_readme: + model_name = repo_id.split('/')[-1] + readme_path = Path(tmpdir) / "README.md" + readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {model_name}' + readme_path.write_text(readme_text) + + # Upload model and return + return upload_folder( + repo_id=repo_id, + folder_path=tmpdir, + revision=revision, + create_pr=create_pr, + commit_message=commit_message, + ) diff --git a/timm/models/_manipulate.py b/timm/models/_manipulate.py new file mode 100644 index 00000000..192979fc --- /dev/null +++ b/timm/models/_manipulate.py @@ -0,0 +1,258 @@ +import collections.abc +import math +import re +from collections import defaultdict +from itertools import chain +from typing import Callable, Union, Dict + +import torch +from torch import nn as nn +from torch.utils.checkpoint import checkpoint + +__all__ = ['model_parameters', 'named_apply', 'named_modules', 'named_modules_with_params', 'adapt_input_conv', + 'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq'] + + +def model_parameters(model, exclude_head=False): + if exclude_head: + # FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering + return [p for p in model.parameters()][:-2] + else: + return model.parameters() + + +def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = '.'.join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +def named_modules(module: nn.Module, name='', depth_first=True, include_root=False): + if not depth_first and include_root: + yield name, module + for child_name, child_module in module.named_children(): + child_name = '.'.join((name, child_name)) if name else child_name + yield from named_modules( + module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + yield name, module + + +def named_modules_with_params(module: nn.Module, name='', depth_first=True, include_root=False): + if module._parameters and not depth_first and include_root: + yield name, module + for child_name, child_module in module.named_children(): + child_name = '.'.join((name, child_name)) if name else child_name + yield from named_modules_with_params( + module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if module._parameters and depth_first and include_root: + yield name, module + + +MATCH_PREV_GROUP = (99999,) + + +def group_with_matcher( + named_objects, + group_matcher: Union[Dict, Callable], + output_values: bool = False, + reverse: bool = False +): + if isinstance(group_matcher, dict): + # dictionary matcher contains a dict of raw-string regex expr that must be compiled + compiled = [] + for group_ordinal, (group_name, mspec) in enumerate(group_matcher.items()): + if mspec is None: + continue + # map all matching specifications into 3-tuple (compiled re, prefix, suffix) + if isinstance(mspec, (tuple, list)): + # multi-entry match specifications require each sub-spec to be a 2-tuple (re, suffix) + for sspec in mspec: + compiled += 
[(re.compile(sspec[0]), (group_ordinal,), sspec[1])] + else: + compiled += [(re.compile(mspec), (group_ordinal,), None)] + group_matcher = compiled + + def _get_grouping(name): + if isinstance(group_matcher, (list, tuple)): + for match_fn, prefix, suffix in group_matcher: + r = match_fn.match(name) + if r: + parts = (prefix, r.groups(), suffix) + # map all tuple elem to int for numeric sort, filter out None entries + return tuple(map(float, chain.from_iterable(filter(None, parts)))) + return float('inf'), # un-matched layers (neck, head) mapped to largest ordinal + else: + ord = group_matcher(name) + if not isinstance(ord, collections.abc.Iterable): + return ord, + return tuple(ord) + + # map layers into groups via ordinals (ints or tuples of ints) from matcher + grouping = defaultdict(list) + for k, v in named_objects: + grouping[_get_grouping(k)].append(v if output_values else k) + + # remap to integers + layer_id_to_param = defaultdict(list) + lid = -1 + for k in sorted(filter(lambda x: x is not None, grouping.keys())): + if lid < 0 or k[-1] != MATCH_PREV_GROUP[0]: + lid += 1 + layer_id_to_param[lid].extend(grouping[k]) + + if reverse: + assert not output_values, "reverse mapping only sensible for name output" + # output reverse mapping + param_to_layer_id = {} + for lid, lm in layer_id_to_param.items(): + for n in lm: + param_to_layer_id[n] = lid + return param_to_layer_id + + return layer_id_to_param + + +def group_parameters( + module: nn.Module, + group_matcher, + output_values=False, + reverse=False, +): + return group_with_matcher( + module.named_parameters(), group_matcher, output_values=output_values, reverse=reverse) + + +def group_modules( + module: nn.Module, + group_matcher, + output_values=False, + reverse=False, +): + return group_with_matcher( + named_modules_with_params(module), group_matcher, output_values=output_values, reverse=reverse) + + +def flatten_modules(named_modules, depth=1, prefix='', module_types='sequential'): + prefix_is_tuple = isinstance(prefix, tuple) + if isinstance(module_types, str): + if module_types == 'container': + module_types = (nn.Sequential, nn.ModuleList, nn.ModuleDict) + else: + module_types = (nn.Sequential,) + for name, module in named_modules: + if depth and isinstance(module, module_types): + yield from flatten_modules( + module.named_children(), + depth - 1, + prefix=(name,) if prefix_is_tuple else name, + module_types=module_types, + ) + else: + if prefix_is_tuple: + name = prefix + (name,) + yield name, module + else: + if prefix: + name = '.'.join([prefix, name]) + yield name, module + + +def checkpoint_seq( + functions, + x, + every=1, + flatten=False, + skip_last=False, + preserve_rng_state=True +): + r"""A helper function for checkpointing sequential models. + + Sequential models execute a list of modules/functions in order + (sequentially). Therefore, we can divide such a sequence into segments + and checkpoint each segment. All segments except run in :func:`torch.no_grad` + manner, i.e., not storing the intermediate activations. The inputs of each + checkpointed segment will be saved for re-running the segment in the backward pass. + + See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works. + + .. warning:: + Checkpointing currently only supports :func:`torch.autograd.backward` + and only if its `inputs` argument is not passed. :func:`torch.autograd.grad` + is not supported. + + .. 
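A sketch of the dict-style matcher applied to parameter grouping (useful for layer-wise LR decay). The regexes below are assumptions written against resnet50's module names; most timm models expose their own matcher via a `group_matcher()` method instead:

```python
import timm
from timm.models._manipulate import group_parameters

model = timm.create_model('resnet50', pretrained=False)

group_matcher = {
    'stem': r'^conv1|^bn1',    # group 0
    'blocks': r'^layer(\d+)',  # one group per residual stage
}
groups = group_parameters(model, group_matcher)  # {layer_id: [param names]}
for layer_id in sorted(groups):
    # unmatched params (the fc classifier) land in the last group
    print(layer_id, len(groups[layer_id]), 'params')
```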
warning: + At least one of the inputs needs to have :code:`requires_grad=True` if + grads are needed for model inputs, otherwise the checkpointed part of the + model won't have gradients. + + Args: + functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially. + x: A Tensor that is input to :attr:`functions` + every: checkpoint every-n functions (default: 1) + flatten (bool): flatten nn.Sequential of nn.Sequentials + skip_last (bool): skip checkpointing the last function in the sequence if True + preserve_rng_state (bool, optional, default=True): Omit stashing and restoring + the RNG state during each checkpoint. + + Returns: + Output of running :attr:`functions` sequentially on :attr:`*inputs` + + Example: + >>> model = nn.Sequential(...) + >>> input_var = checkpoint_seq(model, input_var, every=2) + """ + def run_function(start, end, functions): + def forward(_x): + for j in range(start, end + 1): + _x = functions[j](_x) + return _x + return forward + + if isinstance(functions, torch.nn.Sequential): + functions = functions.children() + if flatten: + functions = chain.from_iterable(functions) + if not isinstance(functions, (tuple, list)): + functions = tuple(functions) + + num_checkpointed = len(functions) + if skip_last: + num_checkpointed -= 1 + end = -1 + for start in range(0, num_checkpointed, every): + end = min(start + every - 1, num_checkpointed - 1) + x = checkpoint(run_function(start, end, functions), x, preserve_rng_state=preserve_rng_state) + if skip_last: + return run_function(end + 1, len(functions) - 1, functions)(x) + return x + + +def adapt_input_conv(in_chans, conv_weight): + conv_type = conv_weight.dtype + conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU + O, I, J, K = conv_weight.shape + if in_chans == 1: + if I > 3: + assert conv_weight.shape[1] % 3 == 0 + # For models with space2depth stems + conv_weight = conv_weight.reshape(O, I // 3, 3, J, K) + conv_weight = conv_weight.sum(dim=2, keepdim=False) + else: + conv_weight = conv_weight.sum(dim=1, keepdim=True) + elif in_chans != 3: + if I != 3: + raise NotImplementedError('Weight format not supported by conversion.') + else: + # NOTE this strategy should be better than random init, but there could be other combinations of + # the original RGB input layer weights that'd work better for specific cases. 
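A small standalone sketch of `checkpoint_seq`; the toy `nn.Sequential` stands in for a model's block stage:

```python
import torch
import torch.nn as nn
from timm.models._manipulate import checkpoint_seq

# 8 blocks checkpointed two at a time; activations inside each segment are
# recomputed during backward instead of being stored
blocks = nn.Sequential(*[nn.Sequential(nn.Linear(256, 256), nn.GELU()) for _ in range(8)])
x = torch.randn(4, 256, requires_grad=True)  # input needs grad for checkpointing to be useful
out = checkpoint_seq(blocks, x, every=2)
out.sum().backward()
```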
+ repeat = int(math.ceil(in_chans / 3)) + conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] + conv_weight *= (3 / float(in_chans)) + conv_weight = conv_weight.to(conv_type) + return conv_weight diff --git a/timm/models/pretrained.py b/timm/models/_pretrained.py similarity index 97% rename from timm/models/pretrained.py rename to timm/models/_pretrained.py index 2ca7ac5a..b5ecbc50 100644 --- a/timm/models/pretrained.py +++ b/timm/models/_pretrained.py @@ -4,6 +4,9 @@ from dataclasses import dataclass, field, replace, asdict from typing import Any, Deque, Dict, Tuple, Optional, Union +__all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg', 'split_model_name_tag', 'generate_default_cfgs'] + + @dataclass class PretrainedCfg: """ diff --git a/timm/models/_prune.py b/timm/models/_prune.py new file mode 100644 index 00000000..4e744dec --- /dev/null +++ b/timm/models/_prune.py @@ -0,0 +1,113 @@ +import os +from copy import deepcopy + +from torch import nn as nn + +from timm.layers import Conv2dSame, BatchNormAct2d, Linear + +__all__ = ['extract_layer', 'set_layer', 'adapt_model_from_string', 'adapt_model_from_file'] + + +def extract_layer(model, layer): + layer = layer.split('.') + module = model + if hasattr(model, 'module') and layer[0] != 'module': + module = model.module + if not hasattr(model, 'module') and layer[0] == 'module': + layer = layer[1:] + for l in layer: + if hasattr(module, l): + if not l.isdigit(): + module = getattr(module, l) + else: + module = module[int(l)] + else: + return module + return module + + +def set_layer(model, layer, val): + layer = layer.split('.') + module = model + if hasattr(model, 'module') and layer[0] != 'module': + module = model.module + lst_index = 0 + module2 = module + for l in layer: + if hasattr(module2, l): + if not l.isdigit(): + module2 = getattr(module2, l) + else: + module2 = module2[int(l)] + lst_index += 1 + lst_index -= 1 + for l in layer[:lst_index]: + if not l.isdigit(): + module = getattr(module, l) + else: + module = module[int(l)] + l = layer[lst_index] + setattr(module, l, val) + + +def adapt_model_from_string(parent_module, model_string): + separator = '***' + state_dict = {} + lst_shape = model_string.split(separator) + for k in lst_shape: + k = k.split(':') + key = k[0] + shape = k[1][1:-1].split(',') + if shape[0] != '': + state_dict[key] = [int(i) for i in shape] + + new_module = deepcopy(parent_module) + for n, m in parent_module.named_modules(): + old_module = extract_layer(parent_module, n) + if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame): + if isinstance(old_module, Conv2dSame): + conv = Conv2dSame + else: + conv = nn.Conv2d + s = state_dict[n + '.weight'] + in_channels = s[1] + out_channels = s[0] + g = 1 + if old_module.groups > 1: + in_channels = out_channels + g = in_channels + new_conv = conv( + in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size, + bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation, + groups=g, stride=old_module.stride) + set_layer(new_module, n, new_conv) + elif isinstance(old_module, BatchNormAct2d): + new_bn = BatchNormAct2d( + state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum, + affine=old_module.affine, track_running_stats=True) + new_bn.drop = old_module.drop + new_bn.act = old_module.act + set_layer(new_module, n, new_bn) + elif isinstance(old_module, nn.BatchNorm2d): + new_bn = nn.BatchNorm2d( + num_features=state_dict[n + 
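`adapt_input_conv` is used by the pretrained weight loading path when a non-RGB `in_chans` is requested; a quick check of its two branches on a resnet50 stem:

```python
import timm
from timm.models._manipulate import adapt_input_conv

model = timm.create_model('resnet50', pretrained=False)
w = model.conv1.weight               # (64, 3, 7, 7) RGB stem weights

print(adapt_input_conv(1, w).shape)  # torch.Size([64, 1, 7, 7])  summed over RGB
print(adapt_input_conv(6, w).shape)  # torch.Size([64, 6, 7, 7])  tiled + rescaled
```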
'.weight'][0], eps=old_module.eps, momentum=old_module.momentum, + affine=old_module.affine, track_running_stats=True) + set_layer(new_module, n, new_bn) + elif isinstance(old_module, nn.Linear): + # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer? + num_features = state_dict[n + '.weight'][1] + new_fc = Linear( + in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None) + set_layer(new_module, n, new_fc) + if hasattr(new_module, 'num_features'): + new_module.num_features = num_features + new_module.eval() + parent_module.eval() + + return new_module + + +def adapt_model_from_file(parent_module, model_variant): + adapt_file = os.path.join(os.path.dirname(__file__), '_pruned', model_variant + '.txt') + with open(adapt_file, 'r') as f: + return adapt_model_from_string(parent_module, f.read().strip()) diff --git a/timm/models/pruned/ecaresnet101d_pruned.txt b/timm/models/_pruned/ecaresnet101d_pruned.txt similarity index 100% rename from timm/models/pruned/ecaresnet101d_pruned.txt rename to timm/models/_pruned/ecaresnet101d_pruned.txt diff --git a/timm/models/pruned/ecaresnet50d_pruned.txt b/timm/models/_pruned/ecaresnet50d_pruned.txt similarity index 100% rename from timm/models/pruned/ecaresnet50d_pruned.txt rename to timm/models/_pruned/ecaresnet50d_pruned.txt diff --git a/timm/models/pruned/efficientnet_b1_pruned.txt b/timm/models/_pruned/efficientnet_b1_pruned.txt similarity index 100% rename from timm/models/pruned/efficientnet_b1_pruned.txt rename to timm/models/_pruned/efficientnet_b1_pruned.txt diff --git a/timm/models/pruned/efficientnet_b2_pruned.txt b/timm/models/_pruned/efficientnet_b2_pruned.txt similarity index 100% rename from timm/models/pruned/efficientnet_b2_pruned.txt rename to timm/models/_pruned/efficientnet_b2_pruned.txt diff --git a/timm/models/pruned/efficientnet_b3_pruned.txt b/timm/models/_pruned/efficientnet_b3_pruned.txt similarity index 100% rename from timm/models/pruned/efficientnet_b3_pruned.txt rename to timm/models/_pruned/efficientnet_b3_pruned.txt diff --git a/timm/models/_registry.py b/timm/models/_registry.py new file mode 100644 index 00000000..fc7b3437 --- /dev/null +++ b/timm/models/_registry.py @@ -0,0 +1,212 @@ +""" Model Registry +Hacked together by / Copyright 2020 Ross Wightman +""" + +import fnmatch +import re +import sys +from collections import defaultdict, deque +from copy import deepcopy +from typing import List, Optional, Union, Tuple + +from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag + +__all__ = [ + 'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules', + 'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name'] + +_module_to_models = defaultdict(set) # dict of sets to check membership of model in module +_model_to_module = {} # mapping of model names to module names +_model_entrypoints = {} # mapping of model names to architecture entrypoint fns +_model_has_pretrained = set() # set of model names that have pretrained weight url present +_model_default_cfgs = dict() # central repo for model arch -> default cfg objects +_model_pretrained_cfgs = dict() # central repo for model arch + tag -> pretrained cfgs +_model_with_tags = defaultdict(list) # shortcut to map each model arch to all model + tag names + + +def get_arch_name(model_name: str) -> Tuple[str, Optional[str]]: + return split_model_name_tag(model_name)[0] + + +def register_model(fn): + # lookup containing 
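The `_pruned/*.txt` files renamed above are consumed by `adapt_model_from_file`; for example, the `ecaresnet50d_pruned` entrypoint passes `pruned=True` through `build_model_with_cfg`, which triggers this adaptation. A rough sketch:

```python
import timm

# builds the unpruned architecture first, then reshapes layers per _pruned/ecaresnet50d_pruned.txt
model = timm.create_model('ecaresnet50d_pruned', pretrained=False)
print(sum(p.numel() for p in model.parameters()) / 1e6, 'M params')
```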
module + mod = sys.modules[fn.__module__] + module_name_split = fn.__module__.split('.') + module_name = module_name_split[-1] if len(module_name_split) else '' + + # add model to __all__ in module + model_name = fn.__name__ + if hasattr(mod, '__all__'): + mod.__all__.append(model_name) + else: + mod.__all__ = [model_name] + + # add entries to registry dict/sets + _model_entrypoints[model_name] = fn + _model_to_module[model_name] = module_name + _module_to_models[module_name].add(model_name) + if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs: + # this will catch all models that have entrypoint matching cfg key, but miss any aliasing + # entrypoints or non-matching combos + cfg = mod.default_cfgs[model_name] + if not isinstance(cfg, DefaultCfg): + # new style default cfg dataclass w/ multiple entries per model-arch + assert isinstance(cfg, dict) + # old style cfg dict per model-arch + cfg = PretrainedCfg(**cfg) + cfg = DefaultCfg(tags=deque(['']), cfgs={'': cfg}) + + for tag_idx, tag in enumerate(cfg.tags): + is_default = tag_idx == 0 + pretrained_cfg = cfg.cfgs[tag] + if is_default: + _model_pretrained_cfgs[model_name] = pretrained_cfg + if pretrained_cfg.has_weights: + # add tagless entry if it's default and has weights + _model_has_pretrained.add(model_name) + if tag: + model_name_tag = '.'.join([model_name, tag]) + _model_pretrained_cfgs[model_name_tag] = pretrained_cfg + if pretrained_cfg.has_weights: + # add model w/ tag if tag is valid + _model_has_pretrained.add(model_name_tag) + _model_with_tags[model_name].append(model_name_tag) + else: + _model_with_tags[model_name].append(model_name) # has empty tag (to slowly remove these instances) + + _model_default_cfgs[model_name] = cfg + + return fn + + +def _natural_key(string_): + return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] + + +def list_models( + filter: Union[str, List[str]] = '', + module: str = '', + pretrained=False, + exclude_filters: str = '', + name_matches_cfg: bool = False, + include_tags: Optional[bool] = None, +): + """ Return list of available model names, sorted alphabetically + + Args: + filter (str) - Wildcard filter string that works with fnmatch + module (str) - Limit model selection to a specific submodule (ie 'vision_transformer') + pretrained (bool) - Include only models with valid pretrained weights if True + exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter + name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases) + include_tags (Optional[boo]) - Include pretrained tags in model names (model.tag). If None, defaults + set to True when pretrained=True else False (default: None) + Example: + model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet' + model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module + """ + if include_tags is None: + # FIXME should this be default behaviour? or default to include_tags=True? 
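A self-contained sketch of registering a new entrypoint; `demo_net` is an illustrative name, not a real timm model, and the module-level `default_cfgs` dict uses the old-style per-arch dict form handled by `register_model` above:

```python
import torch.nn as nn
from timm.models._registry import register_model, is_model, model_entrypoint

# old-style dict cfg keyed by the entrypoint name; picked up by register_model
default_cfgs = {
    'demo_net': {'url': '', 'num_classes': 1000, 'input_size': (3, 224, 224)},
}


@register_model
def demo_net(pretrained=False, **kwargs):
    assert not pretrained, 'this sketch has no pretrained weights'
    return nn.Sequential(nn.Conv2d(3, 1000, 1), nn.AdaptiveAvgPool2d(1), nn.Flatten())


print(is_model('demo_net'))             # True
model = model_entrypoint('demo_net')()  # roughly what create_model does after name parsing
```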
+ include_tags = pretrained + + if module: + all_models = list(_module_to_models[module]) + else: + all_models = _model_entrypoints.keys() + + if include_tags: + # expand model names to include names w/ pretrained tags + models_with_tags = [] + for m in all_models: + models_with_tags.extend(_model_with_tags[m]) + all_models = models_with_tags + + if filter: + models = [] + include_filters = filter if isinstance(filter, (tuple, list)) else [filter] + for f in include_filters: + include_models = fnmatch.filter(all_models, f) # include these models + if len(include_models): + models = set(models).union(include_models) + else: + models = all_models + + if exclude_filters: + if not isinstance(exclude_filters, (tuple, list)): + exclude_filters = [exclude_filters] + for xf in exclude_filters: + exclude_models = fnmatch.filter(models, xf) # exclude these models + if len(exclude_models): + models = set(models).difference(exclude_models) + + if pretrained: + models = _model_has_pretrained.intersection(models) + + if name_matches_cfg: + models = set(_model_pretrained_cfgs).intersection(models) + + return list(sorted(models, key=_natural_key)) + + +def list_pretrained( + filter: Union[str, List[str]] = '', + exclude_filters: str = '', +): + return list_models( + filter=filter, + pretrained=True, + exclude_filters=exclude_filters, + include_tags=True, + ) + + +def is_model(model_name): + """ Check if a model name exists + """ + arch_name = get_arch_name(model_name) + return arch_name in _model_entrypoints + + +def model_entrypoint(model_name, module_filter: Optional[str] = None): + """Fetch a model entrypoint for specified model name + """ + arch_name = get_arch_name(model_name) + if module_filter and arch_name not in _module_to_models.get(module_filter, {}): + raise RuntimeError(f'Model ({model_name} not found in module {module_filter}.') + return _model_entrypoints[arch_name] + + +def list_modules(): + """ Return list of module names that contain models / model entrypoints + """ + modules = _module_to_models.keys() + return list(sorted(modules)) + + +def is_model_in_modules(model_name, module_names): + """Check if a model exists within a subset of modules + Args: + model_name (str) - name of model to check + module_names (tuple, list, set) - names of modules to search in + """ + arch_name = get_arch_name(model_name) + assert isinstance(module_names, (tuple, list, set)) + return any(arch_name in _module_to_models[n] for n in module_names) + + +def is_model_pretrained(model_name): + return model_name in _model_has_pretrained + + +def get_pretrained_cfg(model_name): + if model_name in _model_pretrained_cfgs: + return deepcopy(_model_pretrained_cfgs[model_name]) + raise RuntimeError(f'No pretrained config exists for model {model_name}.') + + +def get_pretrained_cfg_value(model_name, cfg_key): + """ Get a specific model default_cfg value by key. None if key doesn't exist. 
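Examples of the listing helpers above; exact results depend on the installed version, so outputs aren't shown:

```python
import timm
from timm.models._registry import list_pretrained

# wildcard filters use fnmatch; exclude_filters prunes after inclusion
print(timm.list_models('convnext*', exclude_filters='*_384')[:5])

# pretrained=True now defaults to include_tags=True, so names come back as `arch.tag`
print(timm.list_models('vit_base_patch16_224*', pretrained=True))

# shortcut for the above
print(list_pretrained('vit_base_patch16_224*'))
```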
+ """ + if model_name in _model_pretrained_cfgs: + return getattr(_model_pretrained_cfgs[model_name], cfg_key, None) + raise RuntimeError(f'No pretrained config exist for model {model_name}.') \ No newline at end of file diff --git a/timm/models/beit.py b/timm/models/beit.py index c44256a3..de71f441 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -61,12 +61,14 @@ import torch.nn.functional as F from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD -from .helpers import build_model_with_cfg -from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_ -from .pretrained import generate_default_cfgs -from .registry import register_model +from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_ +from ._builder import build_model_with_cfg +from ._pretrained import generate_default_cfgs +from ._registry import register_model from .vision_transformer import checkpoint_filter_fn +__all__ = ['Beit'] + def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor: num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index 3815fa30..c67144cc 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -13,9 +13,9 @@ Consider all of the models definitions here as experimental WIP and likely to ch Hacked together by / copyright Ross Wightman, 2021. """ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from ._builder import build_model_with_cfg +from ._registry import register_model from .byobnet import ByoBlockCfg, ByoModelCfg, ByobNet, interleave_blocks -from .helpers import build_model_with_cfg -from .registry import register_model __all__ = [] diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index 1e402629..0e5c9c7f 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -26,18 +26,18 @@ Hacked together by / copyright Ross Wightman, 2021. """ import math from dataclasses import dataclass, field, replace -from typing import Tuple, List, Dict, Optional, Union, Any, Callable, Sequence from functools import partial +from typing import Tuple, List, Dict, Optional, Union, Any, Callable, Sequence import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, named_apply, checkpoint_seq -from .layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \ - create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0, EvoNorm2dS0a,\ - EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a, FilterResponseNormAct2d, FilterResponseNormTlu2d -from .registry import register_model +from timm.layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \ + create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0a +from ._builder import build_model_with_cfg +from ._manipulate import named_apply, checkpoint_seq +from ._registry import register_model __all__ = ['ByobNet', 'ByoModelCfg', 'ByoBlockCfg', 'create_byob_stem', 'create_block'] diff --git a/timm/models/cait.py b/timm/models/cait.py index c0892099..15dcd956 100644 --- a/timm/models/cait.py +++ b/timm/models/cait.py @@ -8,17 +8,16 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W """ # Copyright (c) 2015-present, Facebook, Inc. # All rights reserved. 
-from copy import deepcopy from functools import partial import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_ -from .registry import register_model - +from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_ +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model __all__ = ['Cait', 'ClassAttn', 'LayerScaleBlockClassAttn', 'LayerScaleBlock', 'TalkingHeadAttn'] diff --git a/timm/models/coat.py b/timm/models/coat.py index c3071a6c..4ed6d8e8 100644 --- a/timm/models/coat.py +++ b/timm/models/coat.py @@ -7,7 +7,6 @@ Official CoaT code at: https://github.com/mlpc-ucsd/CoaT Modified from timm/models/vision_transformer.py """ -from copy import deepcopy from functools import partial from typing import Tuple, List, Union @@ -16,19 +15,11 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_ -from .registry import register_model -from .layers import _assert - - -__all__ = [ - "coat_tiny", - "coat_mini", - "coat_lite_tiny", - "coat_lite_mini", - "coat_lite_small" -] +from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert +from ._builder import build_model_with_cfg +from ._registry import register_model + +__all__ = ['CoaT'] def _cfg_coat(url='', **kwargs): diff --git a/timm/models/convit.py b/timm/models/convit.py index 26849f6e..d117ccdc 100644 --- a/timm/models/convit.py +++ b/timm/models/convit.py @@ -22,20 +22,20 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py ''' +from functools import partial + import torch import torch.nn as nn -from functools import partial -import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import DropPath, to_2tuple, trunc_normal_, PatchEmbed, Mlp -from .registry import register_model +from timm.layers import DropPath, trunc_normal_, PatchEmbed, Mlp +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_module +from ._registry import register_model from .vision_transformer_hybrid import HybridEmbed -from .fx_features import register_notrace_module -import torch -import torch.nn as nn + +__all__ = ['ConViT'] def _cfg(url='', **kwargs): diff --git a/timm/models/convmixer.py b/timm/models/convmixer.py index e7e2481a..3a8c6cf5 100644 --- a/timm/models/convmixer.py +++ b/timm/models/convmixer.py @@ -5,9 +5,12 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.models.registry import register_model -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import SelectAdaptivePool2d +from timm.layers import SelectAdaptivePool2d +from ._registry import register_model +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq + +__all__ = ['ConvMixer'] def _cfg(url='', **kwargs): diff --git a/timm/models/convnext.py b/timm/models/convnext.py index 36a484b3..eea5782a 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -18,12 +18,12 @@ import torch 
import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import named_apply, build_model_with_cfg, checkpoint_seq -from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, LayerNorm, \ +from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, LayerNorm, \ create_conv2d, get_act_layer, make_divisible, to_ntuple -from .pretrained import generate_default_cfgs -from .registry import register_model - +from ._builder import build_model_with_cfg +from ._manipulate import named_apply, checkpoint_seq +from ._pretrained import generate_default_cfgs +from ._registry import register_model __all__ = ['ConvNeXt'] # model_registry will add each entrypoint fn to this diff --git a/timm/models/crossvit.py b/timm/models/crossvit.py index 764eb3fe..908fcf6d 100644 --- a/timm/models/crossvit.py +++ b/timm/models/crossvit.py @@ -24,21 +24,22 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W Modifed from Timm. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py """ +from functools import partial +from typing import List from typing import Tuple import torch -import torch.nn as nn -import torch.nn.functional as F import torch.hub -from functools import partial -from typing import List +import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_function -from .helpers import build_model_with_cfg -from .layers import DropPath, to_2tuple, trunc_normal_, _assert -from .registry import register_model -from .vision_transformer import Mlp, Block +from timm.layers import DropPath, to_2tuple, trunc_normal_, _assert +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._registry import register_model +from .vision_transformer import Block + +__all__ = ['CrossViT'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py index 2c09e7e3..280f929e 100644 --- a/timm/models/cspnet.py +++ b/timm/models/cspnet.py @@ -12,20 +12,18 @@ Reference impl via darknet cfg files at https://github.com/WongKinYiu/CrossStage Hacked together by / Copyright 2020 Ross Wightman """ -import collections.abc -from dataclasses import dataclass, field, asdict +from dataclasses import dataclass, asdict from functools import partial -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import torch import torch.nn as nn -import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, named_apply, MATCH_PREV_GROUP -from .layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, get_attn, create_act_layer, make_divisible -from .registry import register_model - +from timm.layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, get_attn, create_act_layer, make_divisible +from ._builder import build_model_with_cfg +from ._manipulate import named_apply, MATCH_PREV_GROUP +from ._registry import register_model __all__ = ['CspNet'] # model_registry will add each entrypoint fn to this diff --git a/timm/models/deit.py b/timm/models/deit.py index 3205b024..24fbbe56 100644 --- a/timm/models/deit.py +++ b/timm/models/deit.py @@ -17,9 +17,11 @@ from torch import nn as nn from timm.data import 
IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.vision_transformer import VisionTransformer, trunc_normal_, checkpoint_filter_fn +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model -from .helpers import build_model_with_cfg, checkpoint_seq -from .registry import register_model +__all__ = ['VisionTransformerDistilled'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/densenet.py b/timm/models/densenet.py index 1afdfd7b..e731f7b0 100644 --- a/timm/models/densenet.py +++ b/timm/models/densenet.py @@ -4,7 +4,6 @@ fixed kwargs passthrough and addition of dynamic global avg/max pool. """ import re from collections import OrderedDict -from functools import partial import torch import torch.nn as nn @@ -13,9 +12,10 @@ import torch.utils.checkpoint as cp from torch.jit.annotations import List from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, MATCH_PREV_GROUP -from .layers import BatchNormAct2d, create_norm_act_layer, BlurPool2d, create_classifier -from .registry import register_model +from timm.layers import BatchNormAct2d, create_norm_act_layer, BlurPool2d, create_classifier +from ._builder import build_model_with_cfg +from ._manipulate import MATCH_PREV_GROUP +from ._registry import register_model __all__ = ['DenseNet'] diff --git a/timm/models/dla.py b/timm/models/dla.py index 0ab807c0..204fcb4b 100644 --- a/timm/models/dla.py +++ b/timm/models/dla.py @@ -13,9 +13,9 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import create_classifier -from .registry import register_model +from timm.layers import create_classifier +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['DLA'] diff --git a/timm/models/dpn.py b/timm/models/dpn.py index 95159729..87bd918f 100644 --- a/timm/models/dpn.py +++ b/timm/models/dpn.py @@ -15,9 +15,9 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import BatchNormAct2d, ConvNormAct, create_conv2d, create_classifier -from .registry import register_model +from timm.layers import BatchNormAct2d, ConvNormAct, create_conv2d, create_classifier +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['DPN'] diff --git a/timm/models/edgenext.py b/timm/models/edgenext.py index 422d4f2c..d90471fb 100644 --- a/timm/models/edgenext.py +++ b/timm/models/edgenext.py @@ -8,20 +8,20 @@ Original code and weights from https://github.com/mmaaz60/EdgeNeXt Modifications and additions for timm by / Copyright 2022, Ross Wightman """ import math -import torch from collections import OrderedDict from functools import partial from typing import Tuple -from torch import nn +import torch import torch.nn.functional as F +from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_module -from .layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d -from .helpers import named_apply, build_model_with_cfg, checkpoint_seq -from .registry import register_model - +from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, 
Mlp, SelectAdaptivePool2d, create_conv2d +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_module +from ._manipulate import named_apply, checkpoint_seq +from ._registry import register_model __all__ = ['EdgeNeXt'] # model_registry will add each entrypoint fn to this diff --git a/timm/models/efficientformer.py b/timm/models/efficientformer.py index 4749d93a..4f33f29a 100644 --- a/timm/models/efficientformer.py +++ b/timm/models/efficientformer.py @@ -18,9 +18,11 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import DropPath, trunc_normal_, to_2tuple, Mlp -from .registry import register_model +from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp +from ._builder import build_model_with_cfg +from ._registry import register_model + +__all__ = ['EfficientFormer'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 3c0efc96..a1324ae3 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -42,15 +42,15 @@ import torch import torch.nn as nn import torch.nn.functional as F - from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .efficientnet_blocks import SqueezeExcite -from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights,\ +from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, GroupNormAct +from ._builder import build_model_with_cfg, pretrained_cfg_for_features +from ._efficientnet_blocks import SqueezeExcite +from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \ round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT -from .features import FeatureInfo, FeatureHooks -from .helpers import build_model_with_cfg, pretrained_cfg_for_features, checkpoint_seq -from .layers import create_conv2d, create_classifier, get_norm_act_layer, EvoNorm2dS0, GroupNormAct -from .registry import register_model +from ._features import FeatureInfo, FeatureHooks +from ._manipulate import checkpoint_seq +from ._registry import register_model __all__ = ['EfficientNet', 'EfficientNetFeatures'] diff --git a/timm/models/factory.py b/timm/models/factory.py index 9e06c1aa..0ae83dc0 100644 --- a/timm/models/factory.py +++ b/timm/models/factory.py @@ -1,100 +1,4 @@ -import os -from typing import Any, Dict, Optional, Union -from urllib.parse import urlsplit +from ._factory import * -from .pretrained import PretrainedCfg, split_model_name_tag -from .helpers import load_checkpoint -from .hub import load_model_config_from_hf -from .layers import set_layer_config -from .registry import is_model, model_entrypoint - - -def parse_model_name(model_name): - if model_name.startswith('hf_hub'): - # NOTE for backwards compat, deprecate hf_hub use - model_name = model_name.replace('hf_hub', 'hf-hub') - parsed = urlsplit(model_name) - assert parsed.scheme in ('', 'timm', 'hf-hub') - if parsed.scheme == 'hf-hub': - # FIXME may use fragment as revision, currently `@` in URI path - return parsed.scheme, parsed.path - else: - model_name = os.path.split(parsed.path)[-1] - return 'timm', model_name - - -def safe_model_name(model_name, remove_source=True): - # return a filename / path safe model name - def make_safe(name): - return ''.join(c if c.isalnum() else '_' for c 
in name).rstrip('_') - if remove_source: - model_name = parse_model_name(model_name)[-1] - return make_safe(model_name) - - -def create_model( - model_name: str, - pretrained: bool = False, - pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None, - pretrained_cfg_overlay: Optional[Dict[str, Any]] = None, - checkpoint_path: str = '', - scriptable: Optional[bool] = None, - exportable: Optional[bool] = None, - no_jit: Optional[bool] = None, - **kwargs, -): - """Create a model - - Lookup model's entrypoint function and pass relevant args to create a new model. - - **kwargs will be passed through entrypoint fn to timm.models.build_model_with_cfg() - and then the model class __init__(). kwargs values set to None are pruned before passing. - - Args: - model_name (str): name of model to instantiate - pretrained (bool): load pretrained ImageNet-1k weights if true - pretrained_cfg (Union[str, dict, PretrainedCfg]): pass in external pretrained_cfg for model - pretrained_cfg_overlay (dict): replace key-values in base pretrained_cfg with these - checkpoint_path (str): path of checkpoint to load _after_ the model is initialized - scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet) - exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet) - no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only) - - Keyword Args: - drop_rate (float): dropout rate for training (default: 0.0) - global_pool (str): global pool type (default: 'avg') - **: other kwargs are consumed by builder or model __init__() - """ - # Parameters that aren't supported by all models or are intended to only override model defaults if set - # should default to None in command line args/cfg. Remove them if they are present and not set so that - # non-supporting models don't break and default args remain in effect. - kwargs = {k: v for k, v in kwargs.items() if v is not None} - - model_source, model_name = parse_model_name(model_name) - if model_source == 'hf-hub': - assert not pretrained_cfg, 'pretrained_cfg should not be set when sourcing model from Hugging Face Hub.' - # For model names specified in the form `hf-hub:path/architecture_name@revision`, - # load model weights + pretrained_cfg from Hugging Face hub. 
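The `create_model` body above (now re-exported from `timm.models._factory`) resolves three name forms: a plain registry name, an `arch.pretrained_tag` name handled by `split_model_name_tag`, and an `hf-hub:` identifier handled by `parse_model_name`. A hedged usage sketch; the tag and hub repo below are illustrative placeholders, not verified weights:

# Sketch: the three model-name forms create_model accepts (tag and hub id are placeholders).
import timm

m1 = timm.create_model('convnext_tiny', pretrained=False)               # plain registry name
m2 = timm.create_model('convnext_tiny.fb_in1k', pretrained=False)       # arch.pretrained_tag form
# m3 = timm.create_model('hf-hub:timm/convnext_tiny.fb_in1k', pretrained=True)  # cfg + weights pulled from the HF Hub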
- pretrained_cfg, model_name = load_model_config_from_hf(model_name) - else: - model_name, pretrained_tag = split_model_name_tag(model_name) - if not pretrained_cfg: - # a valid pretrained_cfg argument takes priority over tag in model name - pretrained_cfg = pretrained_tag - - if not is_model(model_name): - raise RuntimeError('Unknown model (%s)' % model_name) - - create_fn = model_entrypoint(model_name) - with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit): - model = create_fn( - pretrained=pretrained, - pretrained_cfg=pretrained_cfg, - pretrained_cfg_overlay=pretrained_cfg_overlay, - **kwargs, - ) - - if checkpoint_path: - load_checkpoint(model, checkpoint_path) - - return model +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning) diff --git a/timm/models/features.py b/timm/models/features.py index 0bc46419..25605d99 100644 --- a/timm/models/features.py +++ b/timm/models/features.py @@ -1,284 +1,4 @@ -""" PyTorch Feature Extraction Helpers +from ._features import * -A collection of classes, functions, modules to help extract features from models -and provide a common interface for describing them. - -The return_layers, module re-writing idea inspired by torchvision IntermediateLayerGetter -https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py - -Hacked together by / Copyright 2020 Ross Wightman -""" -from collections import OrderedDict, defaultdict -from copy import deepcopy -from functools import partial -from typing import Dict, List, Tuple - -import torch -import torch.nn as nn - - -class FeatureInfo: - - def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]): - prev_reduction = 1 - for fi in feature_info: - # sanity check the mandatory fields, there may be additional fields depending on the model - assert 'num_chs' in fi and fi['num_chs'] > 0 - assert 'reduction' in fi and fi['reduction'] >= prev_reduction - prev_reduction = fi['reduction'] - assert 'module' in fi - self.out_indices = out_indices - self.info = feature_info - - def from_other(self, out_indices: Tuple[int]): - return FeatureInfo(deepcopy(self.info), out_indices) - - def get(self, key, idx=None): - """ Get value by key at specified index (indices) - if idx == None, returns value for key at each output index - if idx is an integer, return value for that feature module index (ignoring output indices) - if idx is a list/tupple, return value for each module index (ignoring output indices) - """ - if idx is None: - return [self.info[i][key] for i in self.out_indices] - if isinstance(idx, (tuple, list)): - return [self.info[i][key] for i in idx] - else: - return self.info[idx][key] - - def get_dicts(self, keys=None, idx=None): - """ return info dicts for specified keys (or all if None) at specified indices (or out_indices if None) - """ - if idx is None: - if keys is None: - return [self.info[i] for i in self.out_indices] - else: - return [{k: self.info[i][k] for k in keys} for i in self.out_indices] - if isinstance(idx, (tuple, list)): - return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx] - else: - return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys} - - def channels(self, idx=None): - """ feature channels accessor - """ - return self.get('num_chs', idx) - - def reduction(self, idx=None): - """ feature reduction (output stride) accessor - """ - return self.get('reduction', idx) - - def 
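The replacement `factory.py` above is the deprecation-shim pattern used for all of the relocated modules: re-export everything from the new private module, then warn on import. A small sketch of what that looks like from the consumer side (the warning fires only on the first import of the legacy path in a process):

# Sketch: legacy import paths still resolve but emit a DeprecationWarning.
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    from timm.models.factory import create_model      # legacy path, backed by ._factory
    print([str(w.message) for w in caught])            # the shim's deprecation message

from timm.models import create_model                   # preferred import going forward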
module_name(self, idx=None): - """ feature module name accessor - """ - return self.get('module', idx) - - def __getitem__(self, item): - return self.info[item] - - def __len__(self): - return len(self.info) - - -class FeatureHooks: - """ Feature Hook Helper - - This module helps with the setup and extraction of hooks for extracting features from - internal nodes in a model by node name. This works quite well in eager Python but needs - redesign for torchscript. - """ - - def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'): - # setup feature hooks - modules = {k: v for k, v in named_modules} - for i, h in enumerate(hooks): - hook_name = h['module'] - m = modules[hook_name] - hook_id = out_map[i] if out_map else hook_name - hook_fn = partial(self._collect_output_hook, hook_id) - hook_type = h.get('hook_type', default_hook_type) - if hook_type == 'forward_pre': - m.register_forward_pre_hook(hook_fn) - elif hook_type == 'forward': - m.register_forward_hook(hook_fn) - else: - assert False, "Unsupported hook type" - self._feature_outputs = defaultdict(OrderedDict) - - def _collect_output_hook(self, hook_id, *args): - x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre - if isinstance(x, tuple): - x = x[0] # unwrap input tuple - self._feature_outputs[x.device][hook_id] = x - - def get_output(self, device) -> Dict[str, torch.tensor]: - output = self._feature_outputs[device] - self._feature_outputs[device] = OrderedDict() # clear after reading - return output - - -def _module_list(module, flatten_sequential=False): - # a yield/iter would be better for this but wouldn't be compatible with torchscript - ml = [] - for name, module in module.named_children(): - if flatten_sequential and isinstance(module, nn.Sequential): - # first level of Sequential containers is flattened into containing model - for child_name, child_module in module.named_children(): - combined = [name, child_name] - ml.append(('_'.join(combined), '.'.join(combined), child_module)) - else: - ml.append((name, name, module)) - return ml - - -def _get_feature_info(net, out_indices): - feature_info = getattr(net, 'feature_info') - if isinstance(feature_info, FeatureInfo): - return feature_info.from_other(out_indices) - elif isinstance(feature_info, (list, tuple)): - return FeatureInfo(net.feature_info, out_indices) - else: - assert False, "Provided feature_info is not valid" - - -def _get_return_layers(feature_info, out_map): - module_names = feature_info.module_name() - return_layers = {} - for i, name in enumerate(module_names): - return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i] - return return_layers - - -class FeatureDictNet(nn.ModuleDict): - """ Feature extractor with OrderedDict return - - Wrap a model and extract features as specified by the out indices, the network is - partially re-built from contained modules. - - There is a strong assumption that the modules have been registered into the model in the same - order as they are used. There should be no reuse of the same nn.Module more than once, including - trivial modules like `self.relu = nn.ReLU`. - - Only submodules that are directly assigned to the model class (`model.feature1`) or at most - one Sequential container deep (`model.features.1`, with flatten_sequent=True) can be captured. 
- All Sequential containers that are directly assigned to the original model will have their - modules assigned to this module with the name `model.features.1` being changed to `model.features_1` - - Arguments: - model (nn.Module): model from which we will extract the features - out_indices (tuple[int]): model output indices to extract features for - out_map (sequence): list or tuple specifying desired return id for each out index, - otherwise str(index) is used - feature_concat (bool): whether to concatenate intermediate features that are lists or tuples - vs select element [0] - flatten_sequential (bool): whether to flatten sequential modules assigned to model - """ - def __init__( - self, model, - out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False): - super(FeatureDictNet, self).__init__() - self.feature_info = _get_feature_info(model, out_indices) - self.concat = feature_concat - self.return_layers = {} - return_layers = _get_return_layers(self.feature_info, out_map) - modules = _module_list(model, flatten_sequential=flatten_sequential) - remaining = set(return_layers.keys()) - layers = OrderedDict() - for new_name, old_name, module in modules: - layers[new_name] = module - if old_name in remaining: - # return id has to be consistently str type for torchscript - self.return_layers[new_name] = str(return_layers[old_name]) - remaining.remove(old_name) - if not remaining: - break - assert not remaining and len(self.return_layers) == len(return_layers), \ - f'Return layers ({remaining}) are not present in model' - self.update(layers) - - def _collect(self, x) -> (Dict[str, torch.Tensor]): - out = OrderedDict() - for name, module in self.items(): - x = module(x) - if name in self.return_layers: - out_id = self.return_layers[name] - if isinstance(x, (tuple, list)): - # If model tap is a tuple or list, concat or select first element - # FIXME this may need to be more generic / flexible for some nets - out[out_id] = torch.cat(x, 1) if self.concat else x[0] - else: - out[out_id] = x - return out - - def forward(self, x) -> Dict[str, torch.Tensor]: - return self._collect(x) - - -class FeatureListNet(FeatureDictNet): - """ Feature extractor with list return - - See docstring for FeatureDictNet above, this class exists only to appease Torchscript typing constraints. - In eager Python we could have returned List[Tensor] vs Dict[id, Tensor] based on a member bool. - """ - def __init__( - self, model, - out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False): - super(FeatureListNet, self).__init__( - model, out_indices=out_indices, out_map=out_map, feature_concat=feature_concat, - flatten_sequential=flatten_sequential) - - def forward(self, x) -> (List[torch.Tensor]): - return list(self._collect(x).values()) - - -class FeatureHookNet(nn.ModuleDict): - """ FeatureHookNet - - Wrap a model and extract features specified by the out indices using forward/forward-pre hooks. - - If `no_rewrite` is True, features are extracted via hooks without modifying the underlying - network in any way. - - If `no_rewrite` is False, the model will be re-written as in the - FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one. 
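`FeatureDictNet` / `FeatureListNet` above are what `build_model_with_cfg` wraps a model in when `features_only=True`; `out_indices` picks stages and `feature_info` describes them. A minimal sketch, assuming a standard `resnet50` entrypoint (channel/reduction values in the comments are what that model reports):

# Sketch: feature-pyramid extraction via the features_only wrapper.
import torch
import timm

backbone = timm.create_model('resnet50', pretrained=False, features_only=True, out_indices=(1, 2, 3, 4))
print(backbone.feature_info.channels())     # e.g. [256, 512, 1024, 2048]
print(backbone.feature_info.reduction())    # e.g. [4, 8, 16, 32]

feats = backbone(torch.randn(1, 3, 224, 224))
print([f.shape for f in feats])              # one feature map per out index, shallow to deep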
- - FIXME this does not currently work with Torchscript, see FeatureHooks class - """ - def __init__( - self, model, - out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, no_rewrite=False, - feature_concat=False, flatten_sequential=False, default_hook_type='forward'): - super(FeatureHookNet, self).__init__() - assert not torch.jit.is_scripting() - self.feature_info = _get_feature_info(model, out_indices) - self.out_as_dict = out_as_dict - layers = OrderedDict() - hooks = [] - if no_rewrite: - assert not flatten_sequential - if hasattr(model, 'reset_classifier'): # make sure classifier is removed? - model.reset_classifier(0) - layers['body'] = model - hooks.extend(self.feature_info.get_dicts()) - else: - modules = _module_list(model, flatten_sequential=flatten_sequential) - remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type - for f in self.feature_info.get_dicts()} - for new_name, old_name, module in modules: - layers[new_name] = module - for fn, fm in module.named_modules(prefix=old_name): - if fn in remaining: - hooks.append(dict(module=fn, hook_type=remaining[fn])) - del remaining[fn] - if not remaining: - break - assert not remaining, f'Return layers ({remaining}) are not present in model' - self.update(layers) - self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map) - - def forward(self, x): - for name, module in self.items(): - x = module(x) - out = self.hooks.get_output(x.device) - return out if self.out_as_dict else list(out.values()) +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning) diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py index b09381b7..0ff3a18b 100644 --- a/timm/models/fx_features.py +++ b/timm/models/fx_features.py @@ -1,106 +1,4 @@ -""" PyTorch FX Based Feature Extraction Helpers -Using https://pytorch.org/vision/stable/feature_extraction.html -""" -from typing import Callable, List, Dict, Union, Type +from ._features_fx import * -import torch -from torch import nn - -from .features import _get_feature_info - -try: - from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor - has_fx_feature_extraction = True -except ImportError: - has_fx_feature_extraction = False - -# Layers we went to treat as leaf modules -from .layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame -from .layers.non_local_attn import BilinearAttnTransform -from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame - -# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here -# BUT modules from timm.models should use the registration mechanism below -_leaf_modules = { - BilinearAttnTransform, # reason: flow control t <= 1 - # Reason: get_same_padding has a max which raises a control flow error - Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame, - CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0]) -} - -try: - from .layers import InplaceAbn - _leaf_modules.add(InplaceAbn) -except ImportError: - pass - - -def register_notrace_module(module: Type[nn.Module]): - """ - Any module not under timm.models.layers should get this decorator if we don't want to trace through it. 
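`register_notrace_module` above (and `register_notrace_function`, both now in `timm.models._features_fx`) mark code whose shape-dependent control flow the FX tracer should treat as a leaf. A sketch with a made-up module and function, just to show the decorator usage:

# Sketch: marking custom control-flow code as FX leaves (GatedScale and window_pad are hypothetical).
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models._features_fx import register_notrace_module, register_notrace_function


@register_notrace_module          # shape-dependent branch would break symbolic tracing
class GatedScale(nn.Module):
    def forward(self, x):
        return x * 2 if x.shape[-1] > 7 else x


@register_notrace_function        # shape-dependent padding amount, also kept as a leaf call
def window_pad(x, multiple: int = 7):
    pad = (multiple - x.shape[-1] % multiple) % multiple
    return F.pad(x, (0, pad))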
- """ - _leaf_modules.add(module) - return module - - -# Functions we want to autowrap (treat them as leaves) -_autowrap_functions = set() - - -def register_notrace_function(func: Callable): - """ - Decorator for functions which ought not to be traced through - """ - _autowrap_functions.add(func) - return func - - -def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]): - assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction' - return _create_feature_extractor( - model, return_nodes, - tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)} - ) - - -class FeatureGraphNet(nn.Module): - """ A FX Graph based feature extractor that works with the model feature_info metadata - """ - def __init__(self, model, out_indices, out_map=None): - super().__init__() - assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction' - self.feature_info = _get_feature_info(model, out_indices) - if out_map is not None: - assert len(out_map) == len(out_indices) - return_nodes = { - info['module']: out_map[i] if out_map is not None else info['module'] - for i, info in enumerate(self.feature_info) if i in out_indices} - self.graph_module = create_feature_extractor(model, return_nodes) - - def forward(self, x): - return list(self.graph_module(x).values()) - - -class GraphExtractNet(nn.Module): - """ A standalone feature extraction wrapper that maps dict -> list or single tensor - NOTE: - * one can use feature_extractor directly if dictionary output is desired - * unlike FeatureGraphNet, this is intended to be used standalone and not with model feature_info - metadata for builtin feature extraction mode - * create_feature_extractor can be used directly if dictionary output is acceptable - - Args: - model: model to extract features from - return_nodes: node names to return features from (dict or list) - squeeze_out: if only one output, and output in list format, flatten to single tensor - """ - def __init__(self, model, return_nodes: Union[Dict[str, str], List[str]], squeeze_out: bool = True): - super().__init__() - self.squeeze_out = squeeze_out - self.graph_module = create_feature_extractor(model, return_nodes) - - def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]: - out = list(self.graph_module(x).values()) - if self.squeeze_out and len(out) == 1: - return out[0] - return out +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning) diff --git a/timm/models/gcvit.py b/timm/models/gcvit.py index fb375e2c..ec9b7e5e 100644 --- a/timm/models/gcvit.py +++ b/timm/models/gcvit.py @@ -28,12 +28,13 @@ import torch.nn as nn import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_function -from .helpers import build_model_with_cfg, named_apply -from .layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d,\ +from timm.layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d, \ get_attn, get_act_layer, get_norm_layer, _assert -from .registry import register_model -from .vision_transformer_relpos import RelPosMlp, RelPosBias # FIXME move to common location +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._manipulate import named_apply +from ._registry import 
register_model +from .vision_transformer_relpos import RelPosBias # FIXME move to common location __all__ = ['GlobalContextVit'] diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py index e19af88b..492049b9 100644 --- a/timm/models/ghostnet.py +++ b/timm/models/ghostnet.py @@ -11,13 +11,12 @@ import torch import torch.nn as nn import torch.nn.functional as F - from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .layers import SelectAdaptivePool2d, Linear, make_divisible -from .efficientnet_blocks import SqueezeExcite, ConvBnAct -from .helpers import build_model_with_cfg, checkpoint_seq -from .registry import register_model - +from timm.layers import SelectAdaptivePool2d, Linear, make_divisible +from ._builder import build_model_with_cfg +from ._efficientnet_blocks import SqueezeExcite, ConvBnAct +from ._manipulate import checkpoint_seq +from ._registry import register_model __all__ = ['GhostNet'] diff --git a/timm/models/gluon_resnet.py b/timm/models/gluon_resnet.py index a1e73554..2b4131fb 100644 --- a/timm/models/gluon_resnet.py +++ b/timm/models/gluon_resnet.py @@ -5,11 +5,13 @@ by Ross Wightman """ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import SEModule -from .registry import register_model +from timm.layers import SEModule +from ._builder import build_model_with_cfg +from ._registry import register_model from .resnet import ResNet, Bottleneck, BasicBlock +__all__ = [] + def _cfg(url='', **kwargs): return { diff --git a/timm/models/gluon_xception.py b/timm/models/gluon_xception.py index a9c946b2..b487d0fd 100644 --- a/timm/models/gluon_xception.py +++ b/timm/models/gluon_xception.py @@ -13,9 +13,9 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import create_classifier, get_padding -from .registry import register_model +from timm.layers import create_classifier, get_padding +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['Xception65'] diff --git a/timm/models/hardcorenas.py b/timm/models/hardcorenas.py index 132eeab4..d77e642a 100644 --- a/timm/models/hardcorenas.py +++ b/timm/models/hardcorenas.py @@ -3,12 +3,14 @@ from functools import partial import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .efficientnet_blocks import SqueezeExcite -from .efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args, round_channels -from .helpers import build_model_with_cfg, pretrained_cfg_for_features -from .layers import get_act_fn +from ._builder import build_model_with_cfg +from ._builder import pretrained_cfg_for_features +from ._efficientnet_blocks import SqueezeExcite +from ._efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args, round_channels +from ._registry import register_model from .mobilenetv3 import MobileNetV3, MobileNetV3Features -from .registry import register_model + +__all__ = [] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/helpers.py b/timm/models/helpers.py index 2a5551e0..6bc82eb8 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -1,855 +1,7 @@ -""" Model creation / weight loading / state_dict helpers +from ._builder import * +from ._helpers import * +from ._manipulate import * +from ._prune import * -Hacked together by / 
Copyright 2020 Ross Wightman -""" -import collections.abc -import dataclasses -import logging -import math -import os -import re -from collections import OrderedDict, defaultdict -from copy import deepcopy -from itertools import chain -from typing import Any, Callable, Optional, Tuple, Dict, Union - -import torch -import torch.nn as nn -from torch.hub import load_state_dict_from_url -from torch.utils.checkpoint import checkpoint - -from .pretrained import PretrainedCfg -from .features import FeatureListNet, FeatureDictNet, FeatureHookNet -from .fx_features import FeatureGraphNet -from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf -from .layers import Conv2dSame, Linear, BatchNormAct2d -from .registry import get_pretrained_cfg - - -_logger = logging.getLogger(__name__) - - -# Global variables for rarely used pretrained checkpoint download progress and hash check. -# Use set_pretrained_download_progress / set_pretrained_check_hash functions to toggle. -_DOWNLOAD_PROGRESS = False -_CHECK_HASH = False - - -def clean_state_dict(state_dict): - # 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training - cleaned_state_dict = OrderedDict() - for k, v in state_dict.items(): - name = k[7:] if k.startswith('module.') else k - cleaned_state_dict[name] = v - return cleaned_state_dict - - -def load_state_dict(checkpoint_path, use_ema=True): - if checkpoint_path and os.path.isfile(checkpoint_path): - checkpoint = torch.load(checkpoint_path, map_location='cpu') - state_dict_key = '' - if isinstance(checkpoint, dict): - if use_ema and checkpoint.get('state_dict_ema', None) is not None: - state_dict_key = 'state_dict_ema' - elif use_ema and checkpoint.get('model_ema', None) is not None: - state_dict_key = 'model_ema' - elif 'state_dict' in checkpoint: - state_dict_key = 'state_dict' - elif 'model' in checkpoint: - state_dict_key = 'model' - state_dict = clean_state_dict(checkpoint[state_dict_key] if state_dict_key else checkpoint) - _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path)) - return state_dict - else: - _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) - raise FileNotFoundError() - - -def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True, remap=False): - if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'): - # numpy checkpoint, try to load via model specific load_pretrained fn - if hasattr(model, 'load_pretrained'): - model.load_pretrained(checkpoint_path) - else: - raise NotImplementedError('Model cannot load numpy checkpoint') - return - state_dict = load_state_dict(checkpoint_path, use_ema) - if remap: - state_dict = remap_checkpoint(model, state_dict) - incompatible_keys = model.load_state_dict(state_dict, strict=strict) - return incompatible_keys - - -def remap_checkpoint(model, state_dict, allow_reshape=True): - """ remap checkpoint by iterating over state dicts in order (ignoring original keys). - This assumes models (and originating state dict) were created with params registered in same order. - """ - out_dict = {} - for (ka, va), (kb, vb) in zip(model.state_dict().items(), state_dict.items()): - assert va.numel == vb.numel, f'Tensor size mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.' - if va.shape != vb.shape: - if allow_reshape: - vb = vb.reshape(va.shape) - else: - assert False, f'Tensor shape mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.' 
- out_dict[ka] = vb - return out_dict - - -def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True): - resume_epoch = None - if os.path.isfile(checkpoint_path): - checkpoint = torch.load(checkpoint_path, map_location='cpu') - if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: - if log_info: - _logger.info('Restoring model state from checkpoint...') - state_dict = clean_state_dict(checkpoint['state_dict']) - model.load_state_dict(state_dict) - - if optimizer is not None and 'optimizer' in checkpoint: - if log_info: - _logger.info('Restoring optimizer state from checkpoint...') - optimizer.load_state_dict(checkpoint['optimizer']) - - if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint: - if log_info: - _logger.info('Restoring AMP loss scaler state from checkpoint...') - loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key]) - - if 'epoch' in checkpoint: - resume_epoch = checkpoint['epoch'] - if 'version' in checkpoint and checkpoint['version'] > 1: - resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save - - if log_info: - _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) - else: - model.load_state_dict(checkpoint) - if log_info: - _logger.info("Loaded checkpoint '{}'".format(checkpoint_path)) - return resume_epoch - else: - _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) - raise FileNotFoundError() - - -def _resolve_pretrained_source(pretrained_cfg): - cfg_source = pretrained_cfg.get('source', '') - pretrained_url = pretrained_cfg.get('url', None) - pretrained_file = pretrained_cfg.get('file', None) - hf_hub_id = pretrained_cfg.get('hf_hub_id', None) - # resolve where to load pretrained weights from - load_from = '' - pretrained_loc = '' - if cfg_source == 'hf-hub' and has_hf_hub(necessary=True): - # hf-hub specified as source via model identifier - load_from = 'hf-hub' - assert hf_hub_id - pretrained_loc = hf_hub_id - else: - # default source == timm or unspecified - if pretrained_file: - load_from = 'file' - pretrained_loc = pretrained_file - elif pretrained_url: - load_from = 'url' - pretrained_loc = pretrained_url - elif hf_hub_id and has_hf_hub(necessary=True): - # hf-hub available as alternate weight source in default_cfg - load_from = 'hf-hub' - pretrained_loc = hf_hub_id - if load_from == 'hf-hub' and pretrained_cfg.get('hf_hub_filename', None): - # if a filename override is set, return tuple for location w/ (hub_id, filename) - pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename'] - return load_from, pretrained_loc - - -def set_pretrained_download_progress(enable=True): - """ Set download progress for pretrained weights on/off (globally). """ - global _DOWNLOAD_PROGRESS - _DOWNLOAD_PROGRESS = enable - - -def set_pretrained_check_hash(enable=True): - """ Set hash checking for pretrained weights on/off (globally). """ - global _CHECK_HASH - _CHECK_HASH = enable - - -def load_custom_pretrained( - model: nn.Module, - pretrained_cfg: Optional[Dict] = None, - load_fn: Optional[Callable] = None, -): - r"""Loads a custom (read non .pth) weight file - - Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls - a passed in custom load fun, or the `load_pretrained` model member fn. - - If the object is already present in `model_dir`, it's deserialized and returned. 
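`load_state_dict`, `load_checkpoint`, and `resume_checkpoint` above keep their behaviour; they are simply re-exported through `timm.models`. A short sketch of the usual pattern; the checkpoint path is a placeholder:

# Sketch: restoring a training checkpoint into a freshly built model (path is a placeholder).
import timm
from timm.models import load_checkpoint, load_state_dict

model = timm.create_model('resnet50', pretrained=False)
ckpt_path = 'output/train/model_best.pth.tar'

# state_dict = load_state_dict(ckpt_path, use_ema=True)    # prefers state_dict_ema / model_ema keys
# load_checkpoint(model, ckpt_path, use_ema=True, strict=True)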
- The default value of `model_dir` is ``/checkpoints`` where - `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`. - - Args: - model: The instantiated model to load weights into - pretrained_cfg (dict): Default pretrained model cfg - load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named - 'laod_pretrained' on the model will be called if it exists - """ - pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None) - if not pretrained_cfg: - _logger.warning("Invalid pretrained config, cannot load weights.") - return - - load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg) - if not load_from: - _logger.warning("No pretrained weights exist for this model. Using random initialization.") - return - if load_from == 'hf-hub': # FIXME - _logger.warning("Hugging Face hub not currently supported for custom load pretrained models.") - elif load_from == 'url': - pretrained_loc = download_cached_file( - pretrained_loc, - check_hash=_CHECK_HASH, - progress=_DOWNLOAD_PROGRESS - ) - - if load_fn is not None: - load_fn(model, pretrained_loc) - elif hasattr(model, 'load_pretrained'): - model.load_pretrained(pretrained_loc) - else: - _logger.warning("Valid function to load pretrained weights is not available, using random initialization.") - - -def adapt_input_conv(in_chans, conv_weight): - conv_type = conv_weight.dtype - conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU - O, I, J, K = conv_weight.shape - if in_chans == 1: - if I > 3: - assert conv_weight.shape[1] % 3 == 0 - # For models with space2depth stems - conv_weight = conv_weight.reshape(O, I // 3, 3, J, K) - conv_weight = conv_weight.sum(dim=2, keepdim=False) - else: - conv_weight = conv_weight.sum(dim=1, keepdim=True) - elif in_chans != 3: - if I != 3: - raise NotImplementedError('Weight format not supported by conversion.') - else: - # NOTE this strategy should be better than random init, but there could be other combinations of - # the original RGB input layer weights that'd work better for specific cases. 
- repeat = int(math.ceil(in_chans / 3)) - conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] - conv_weight *= (3 / float(in_chans)) - conv_weight = conv_weight.to(conv_type) - return conv_weight - - -def load_pretrained( - model: nn.Module, - pretrained_cfg: Optional[Dict] = None, - num_classes: int = 1000, - in_chans: int = 3, - filter_fn: Optional[Callable] = None, - strict: bool = True, -): - """ Load pretrained checkpoint - - Args: - model (nn.Module) : PyTorch model module - pretrained_cfg (Optional[Dict]): configuration for pretrained weights / target dataset - num_classes (int): num_classes for target model - in_chans (int): in_chans for target model - filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args) - strict (bool): strict load of checkpoint - - """ - pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None) - if not pretrained_cfg: - _logger.warning("Invalid pretrained config, cannot load weights.") - return - - load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg) - if load_from == 'file': - _logger.info(f'Loading pretrained weights from file ({pretrained_loc})') - state_dict = load_state_dict(pretrained_loc) - elif load_from == 'url': - _logger.info(f'Loading pretrained weights from url ({pretrained_loc})') - state_dict = load_state_dict_from_url( - pretrained_loc, - map_location='cpu', - progress=_DOWNLOAD_PROGRESS, - check_hash=_CHECK_HASH, - ) - elif load_from == 'hf-hub': - _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})') - if isinstance(pretrained_loc, (list, tuple)): - state_dict = load_state_dict_from_hf(*pretrained_loc) - else: - state_dict = load_state_dict_from_hf(pretrained_loc) - else: - _logger.warning("No pretrained weights exist or were found for this model. 
Using random initialization.") - return - - if filter_fn is not None: - # for backwards compat with filter fn that take one arg, try one first, the two - try: - state_dict = filter_fn(state_dict) - except TypeError: - state_dict = filter_fn(state_dict, model) - - input_convs = pretrained_cfg.get('first_conv', None) - if input_convs is not None and in_chans != 3: - if isinstance(input_convs, str): - input_convs = (input_convs,) - for input_conv_name in input_convs: - weight_name = input_conv_name + '.weight' - try: - state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name]) - _logger.info( - f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)') - except NotImplementedError as e: - del state_dict[weight_name] - strict = False - _logger.warning( - f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.') - - classifiers = pretrained_cfg.get('classifier', None) - label_offset = pretrained_cfg.get('label_offset', 0) - if classifiers is not None: - if isinstance(classifiers, str): - classifiers = (classifiers,) - if num_classes != pretrained_cfg['num_classes']: - for classifier_name in classifiers: - # completely discard fully connected if model num_classes doesn't match pretrained weights - state_dict.pop(classifier_name + '.weight', None) - state_dict.pop(classifier_name + '.bias', None) - strict = False - elif label_offset > 0: - for classifier_name in classifiers: - # special case for pretrained weights with an extra background class in pretrained weights - classifier_weight = state_dict[classifier_name + '.weight'] - state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:] - classifier_bias = state_dict[classifier_name + '.bias'] - state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:] - - model.load_state_dict(state_dict, strict=strict) - - -def extract_layer(model, layer): - layer = layer.split('.') - module = model - if hasattr(model, 'module') and layer[0] != 'module': - module = model.module - if not hasattr(model, 'module') and layer[0] == 'module': - layer = layer[1:] - for l in layer: - if hasattr(module, l): - if not l.isdigit(): - module = getattr(module, l) - else: - module = module[int(l)] - else: - return module - return module - - -def set_layer(model, layer, val): - layer = layer.split('.') - module = model - if hasattr(model, 'module') and layer[0] != 'module': - module = model.module - lst_index = 0 - module2 = module - for l in layer: - if hasattr(module2, l): - if not l.isdigit(): - module2 = getattr(module2, l) - else: - module2 = module2[int(l)] - lst_index += 1 - lst_index -= 1 - for l in layer[:lst_index]: - if not l.isdigit(): - module = getattr(module, l) - else: - module = module[int(l)] - l = layer[lst_index] - setattr(module, l, val) - - -def adapt_model_from_string(parent_module, model_string): - separator = '***' - state_dict = {} - lst_shape = model_string.split(separator) - for k in lst_shape: - k = k.split(':') - key = k[0] - shape = k[1][1:-1].split(',') - if shape[0] != '': - state_dict[key] = [int(i) for i in shape] - - new_module = deepcopy(parent_module) - for n, m in parent_module.named_modules(): - old_module = extract_layer(parent_module, n) - if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame): - if isinstance(old_module, Conv2dSame): - conv = Conv2dSame - else: - conv = nn.Conv2d - s = state_dict[n + '.weight'] - in_channels = s[1] - out_channels = s[0] - g = 1 - if old_module.groups > 1: - 
in_channels = out_channels - g = in_channels - new_conv = conv( - in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size, - bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation, - groups=g, stride=old_module.stride) - set_layer(new_module, n, new_conv) - elif isinstance(old_module, BatchNormAct2d): - new_bn = BatchNormAct2d( - state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum, - affine=old_module.affine, track_running_stats=True) - new_bn.drop = old_module.drop - new_bn.act = old_module.act - set_layer(new_module, n, new_bn) - elif isinstance(old_module, nn.BatchNorm2d): - new_bn = nn.BatchNorm2d( - num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum, - affine=old_module.affine, track_running_stats=True) - set_layer(new_module, n, new_bn) - elif isinstance(old_module, nn.Linear): - # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer? - num_features = state_dict[n + '.weight'][1] - new_fc = Linear( - in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None) - set_layer(new_module, n, new_fc) - if hasattr(new_module, 'num_features'): - new_module.num_features = num_features - new_module.eval() - parent_module.eval() - - return new_module - - -def adapt_model_from_file(parent_module, model_variant): - adapt_file = os.path.join(os.path.dirname(__file__), 'pruned', model_variant + '.txt') - with open(adapt_file, 'r') as f: - return adapt_model_from_string(parent_module, f.read().strip()) - - -def pretrained_cfg_for_features(pretrained_cfg): - pretrained_cfg = deepcopy(pretrained_cfg) - # remove default pretrained cfg fields that don't have much relevance for feature backbone - to_remove = ('num_classes', 'classifier', 'global_pool') # add default final pool size? 
- for tr in to_remove: - pretrained_cfg.pop(tr, None) - return pretrained_cfg - - -def _filter_kwargs(kwargs, names): - if not kwargs or not names: - return - for n in names: - kwargs.pop(n, None) - - -def _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter): - """ Update the default_cfg and kwargs before passing to model - - Args: - pretrained_cfg: input pretrained cfg (updated in-place) - kwargs: keyword args passed to model build fn (updated in-place) - kwargs_filter: keyword arg keys that must be removed before model __init__ - """ - # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs) - default_kwarg_names = ('num_classes', 'global_pool', 'in_chans') - if pretrained_cfg.get('fixed_input_size', False): - # if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size - default_kwarg_names += ('img_size',) - - for n in default_kwarg_names: - # for legacy reasons, model __init__args uses img_size + in_chans as separate args while - # pretrained_cfg has one input_size=(C, H ,W) entry - if n == 'img_size': - input_size = pretrained_cfg.get('input_size', None) - if input_size is not None: - assert len(input_size) == 3 - kwargs.setdefault(n, input_size[-2:]) - elif n == 'in_chans': - input_size = pretrained_cfg.get('input_size', None) - if input_size is not None: - assert len(input_size) == 3 - kwargs.setdefault(n, input_size[0]) - else: - default_val = pretrained_cfg.get(n, None) - if default_val is not None: - kwargs.setdefault(n, pretrained_cfg[n]) - - # Filter keyword args for task specific model variants (some 'features only' models, etc.) - _filter_kwargs(kwargs, names=kwargs_filter) - - -def resolve_pretrained_cfg( - variant: str, - pretrained_cfg=None, - pretrained_cfg_overlay=None, -) -> PretrainedCfg: - model_with_tag = variant - pretrained_tag = None - if pretrained_cfg: - if isinstance(pretrained_cfg, dict): - # pretrained_cfg dict passed as arg, validate by converting to PretrainedCfg - pretrained_cfg = PretrainedCfg(**pretrained_cfg) - elif isinstance(pretrained_cfg, str): - pretrained_tag = pretrained_cfg - pretrained_cfg = None - - # fallback to looking up pretrained cfg in model registry by variant identifier - if not pretrained_cfg: - if pretrained_tag: - model_with_tag = '.'.join([variant, pretrained_tag]) - pretrained_cfg = get_pretrained_cfg(model_with_tag) - - if not pretrained_cfg: - _logger.warning( - f"No pretrained configuration specified for {model_with_tag} model. Using a default." 
- f" Please add a config to the model pretrained_cfg registry or pass explicitly.") - pretrained_cfg = PretrainedCfg() # instance with defaults - - pretrained_cfg_overlay = pretrained_cfg_overlay or {} - if not pretrained_cfg.architecture: - pretrained_cfg_overlay.setdefault('architecture', variant) - pretrained_cfg = dataclasses.replace(pretrained_cfg, **pretrained_cfg_overlay) - - return pretrained_cfg - - -def build_model_with_cfg( - model_cls: Callable, - variant: str, - pretrained: bool, - pretrained_cfg: Optional[Dict] = None, - pretrained_cfg_overlay: Optional[Dict] = None, - model_cfg: Optional[Any] = None, - feature_cfg: Optional[Dict] = None, - pretrained_strict: bool = True, - pretrained_filter_fn: Optional[Callable] = None, - kwargs_filter: Optional[Tuple[str]] = None, - **kwargs, -): - """ Build model with specified default_cfg and optional model_cfg - - This helper fn aids in the construction of a model including: - * handling default_cfg and associated pretrained weight loading - * passing through optional model_cfg for models with config based arch spec - * features_only model adaptation - * pruning config / model adaptation - - Args: - model_cls (nn.Module): model class - variant (str): model variant name - pretrained (bool): load pretrained weights - pretrained_cfg (dict): model's pretrained weight/task config - model_cfg (Optional[Dict]): model's architecture config - feature_cfg (Optional[Dict]: feature extraction adapter config - pretrained_strict (bool): load pretrained weights strictly - pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights - kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model - **kwargs: model args passed through to model __init__ - """ - pruned = kwargs.pop('pruned', False) - features = False - feature_cfg = feature_cfg or {} - - # resolve and update model pretrained config and model kwargs - pretrained_cfg = resolve_pretrained_cfg( - variant, - pretrained_cfg=pretrained_cfg, - pretrained_cfg_overlay=pretrained_cfg_overlay - ) - - # FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model - pretrained_cfg = pretrained_cfg.to_dict() - - _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter) - - # Setup for feature extraction wrapper done at end of this fn - if kwargs.pop('features_only', False): - features = True - feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4)) - if 'out_indices' in kwargs: - feature_cfg['out_indices'] = kwargs.pop('out_indices') - - # Instantiate the model - if model_cfg is None: - model = model_cls(**kwargs) - else: - model = model_cls(cfg=model_cfg, **kwargs) - model.pretrained_cfg = pretrained_cfg - model.default_cfg = model.pretrained_cfg # alias for backwards compat - - if pruned: - model = adapt_model_from_file(model, variant) - - # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats - num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000)) - if pretrained: - if pretrained_cfg.get('custom_load', False): - load_custom_pretrained( - model, - pretrained_cfg=pretrained_cfg, - ) - else: - load_pretrained( - model, - pretrained_cfg=pretrained_cfg, - num_classes=num_classes_pretrained, - in_chans=kwargs.get('in_chans', 3), - filter_fn=pretrained_filter_fn, - strict=pretrained_strict, - ) - - # Wrap the model in a feature extraction module if enabled - if features: - feature_cls = FeatureListNet - if 'feature_cls' in feature_cfg: - 
feature_cls = feature_cfg.pop('feature_cls') - if isinstance(feature_cls, str): - feature_cls = feature_cls.lower() - if 'hook' in feature_cls: - feature_cls = FeatureHookNet - elif feature_cls == 'fx': - feature_cls = FeatureGraphNet - else: - assert False, f'Unknown feature class {feature_cls}' - model = feature_cls(model, **feature_cfg) - model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg) # add back default_cfg - model.default_cfg = model.pretrained_cfg # alias for backwards compat - - return model - - -def model_parameters(model, exclude_head=False): - if exclude_head: - # FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering - return [p for p in model.parameters()][:-2] - else: - return model.parameters() - - -def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module: - if not depth_first and include_root: - fn(module=module, name=name) - for child_name, child_module in module.named_children(): - child_name = '.'.join((name, child_name)) if name else child_name - named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) - if depth_first and include_root: - fn(module=module, name=name) - return module - - -def named_modules(module: nn.Module, name='', depth_first=True, include_root=False): - if not depth_first and include_root: - yield name, module - for child_name, child_module in module.named_children(): - child_name = '.'.join((name, child_name)) if name else child_name - yield from named_modules( - module=child_module, name=child_name, depth_first=depth_first, include_root=True) - if depth_first and include_root: - yield name, module - - -def named_modules_with_params(module: nn.Module, name='', depth_first=True, include_root=False): - if module._parameters and not depth_first and include_root: - yield name, module - for child_name, child_module in module.named_children(): - child_name = '.'.join((name, child_name)) if name else child_name - yield from named_modules_with_params( - module=child_module, name=child_name, depth_first=depth_first, include_root=True) - if module._parameters and depth_first and include_root: - yield name, module - - -MATCH_PREV_GROUP = (99999,) - - -def group_with_matcher( - named_objects, - group_matcher: Union[Dict, Callable], - output_values: bool = False, - reverse: bool = False -): - if isinstance(group_matcher, dict): - # dictionary matcher contains a dict of raw-string regex expr that must be compiled - compiled = [] - for group_ordinal, (group_name, mspec) in enumerate(group_matcher.items()): - if mspec is None: - continue - # map all matching specifications into 3-tuple (compiled re, prefix, suffix) - if isinstance(mspec, (tuple, list)): - # multi-entry match specifications require each sub-spec to be a 2-tuple (re, suffix) - for sspec in mspec: - compiled += [(re.compile(sspec[0]), (group_ordinal,), sspec[1])] - else: - compiled += [(re.compile(mspec), (group_ordinal,), None)] - group_matcher = compiled - - def _get_grouping(name): - if isinstance(group_matcher, (list, tuple)): - for match_fn, prefix, suffix in group_matcher: - r = match_fn.match(name) - if r: - parts = (prefix, r.groups(), suffix) - # map all tuple elem to int for numeric sort, filter out None entries - return tuple(map(float, chain.from_iterable(filter(None, parts)))) - return float('inf'), # un-matched layers (neck, head) mapped to largest ordinal - else: - ord = group_matcher(name) - if not isinstance(ord, 
collections.abc.Iterable): - return ord, - return tuple(ord) - - # map layers into groups via ordinals (ints or tuples of ints) from matcher - grouping = defaultdict(list) - for k, v in named_objects: - grouping[_get_grouping(k)].append(v if output_values else k) - - # remap to integers - layer_id_to_param = defaultdict(list) - lid = -1 - for k in sorted(filter(lambda x: x is not None, grouping.keys())): - if lid < 0 or k[-1] != MATCH_PREV_GROUP[0]: - lid += 1 - layer_id_to_param[lid].extend(grouping[k]) - - if reverse: - assert not output_values, "reverse mapping only sensible for name output" - # output reverse mapping - param_to_layer_id = {} - for lid, lm in layer_id_to_param.items(): - for n in lm: - param_to_layer_id[n] = lid - return param_to_layer_id - - return layer_id_to_param - - -def group_parameters( - module: nn.Module, - group_matcher, - output_values=False, - reverse=False, -): - return group_with_matcher( - module.named_parameters(), group_matcher, output_values=output_values, reverse=reverse) - - -def group_modules( - module: nn.Module, - group_matcher, - output_values=False, - reverse=False, -): - return group_with_matcher( - named_modules_with_params(module), group_matcher, output_values=output_values, reverse=reverse) - - -def checkpoint_seq( - functions, - x, - every=1, - flatten=False, - skip_last=False, - preserve_rng_state=True -): - r"""A helper function for checkpointing sequential models. - - Sequential models execute a list of modules/functions in order - (sequentially). Therefore, we can divide such a sequence into segments - and checkpoint each segment. All segments except run in :func:`torch.no_grad` - manner, i.e., not storing the intermediate activations. The inputs of each - checkpointed segment will be saved for re-running the segment in the backward pass. - - See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works. - - .. warning:: - Checkpointing currently only supports :func:`torch.autograd.backward` - and only if its `inputs` argument is not passed. :func:`torch.autograd.grad` - is not supported. - - .. warning: - At least one of the inputs needs to have :code:`requires_grad=True` if - grads are needed for model inputs, otherwise the checkpointed part of the - model won't have gradients. - - Args: - functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially. - x: A Tensor that is input to :attr:`functions` - every: checkpoint every-n functions (default: 1) - flatten (bool): flatten nn.Sequential of nn.Sequentials - skip_last (bool): skip checkpointing the last function in the sequence if True - preserve_rng_state (bool, optional, default=True): Omit stashing and restoring - the RNG state during each checkpoint. - - Returns: - Output of running :attr:`functions` sequentially on :attr:`*inputs` - - Example: - >>> model = nn.Sequential(...) 
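# Illustrative sketch of the matcher machinery above (hypothetical module and regexes; real
# models expose their own dict via `model.group_matcher()`). Assumes these helpers land in
# timm.models._manipulate with the rest of the moved functions. With reverse=True the result
# maps parameter name -> layer id, which layer-wise LR decay consumers use.
import torch.nn as nn
from timm.models._manipulate import group_parameters

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(3, 8, 3)
        self.blocks = nn.Sequential(nn.Conv2d(8, 8, 3), nn.Conv2d(8, 8, 3))
        self.head = nn.Linear(8, 10)

matcher = dict(stem=r'^stem', blocks=r'^blocks\.(\d+)')  # unmatched names (head) sort last
param_to_layer = group_parameters(Toy(), matcher, reverse=True)
# e.g. {'stem.weight': 0, ..., 'blocks.0.weight': 1, ..., 'blocks.1.bias': 2, 'head.weight': 3, ...}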
- >>> input_var = checkpoint_seq(model, input_var, every=2) - """ - def run_function(start, end, functions): - def forward(_x): - for j in range(start, end + 1): - _x = functions[j](_x) - return _x - return forward - - if isinstance(functions, torch.nn.Sequential): - functions = functions.children() - if flatten: - functions = chain.from_iterable(functions) - if not isinstance(functions, (tuple, list)): - functions = tuple(functions) - - num_checkpointed = len(functions) - if skip_last: - num_checkpointed -= 1 - end = -1 - for start in range(0, num_checkpointed, every): - end = min(start + every - 1, num_checkpointed - 1) - x = checkpoint(run_function(start, end, functions), x, preserve_rng_state=preserve_rng_state) - if skip_last: - return run_function(end + 1, len(functions) - 1, functions)(x) - return x - - -def flatten_modules(named_modules, depth=1, prefix='', module_types='sequential'): - prefix_is_tuple = isinstance(prefix, tuple) - if isinstance(module_types, str): - if module_types == 'container': - module_types = (nn.Sequential, nn.ModuleList, nn.ModuleDict) - else: - module_types = (nn.Sequential,) - for name, module in named_modules: - if depth and isinstance(module, module_types): - yield from flatten_modules( - module.named_children(), - depth - 1, - prefix=(name,) if prefix_is_tuple else name, - module_types=module_types, - ) - else: - if prefix_is_tuple: - name = prefix + (name,) - yield name, module - else: - if prefix: - name = '.'.join([prefix, name]) - yield name, module +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning) diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py index 30860120..338d409e 100644 --- a/timm/models/hrnet.py +++ b/timm/models/hrnet.py @@ -16,12 +16,14 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .features import FeatureInfo -from .helpers import build_model_with_cfg, pretrained_cfg_for_features -from .layers import create_classifier -from .registry import register_model +from timm.layers import create_classifier +from ._builder import build_model_with_cfg, pretrained_cfg_for_features +from ._features import FeatureInfo +from ._registry import register_model from .resnet import BasicBlock, Bottleneck # leveraging ResNet blocks w/ additional features like SE +__all__ = ['HighResolutionNet', 'HighResolutionNetFeatures'] # model_registry will add each entrypoint fn to this + _BN_MOMENTUM = 0.1 _logger = logging.getLogger(__name__) diff --git a/timm/models/hub.py b/timm/models/hub.py index 18c5444a..074ca025 100644 --- a/timm/models/hub.py +++ b/timm/models/hub.py @@ -1,217 +1,4 @@ -import json -import logging -import os -from functools import partial -from pathlib import Path -from tempfile import TemporaryDirectory -from typing import Optional, Union +from _hub import * -import torch -from torch.hub import HASH_REGEX, download_url_to_file, urlparse - -try: - from torch.hub import get_dir -except ImportError: - from torch.hub import _get_torch_home as get_dir - -from timm import __version__ -from timm.models.pretrained import filter_pretrained_cfg - -try: - from huggingface_hub import ( - create_repo, get_hf_file_metadata, - hf_hub_download, hf_hub_url, - repo_type_and_id_from_hf_id, upload_folder) - from huggingface_hub.utils import EntryNotFoundError - hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__) - _has_hf_hub = True -except ImportError: - 
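# Illustrative sketch: how checkpoint_seq above is typically called from a model's forward()
# (the models in this patch import it from ._manipulate) to trade compute for activation memory.
import torch
import torch.nn as nn
from timm.models._manipulate import checkpoint_seq

blocks = nn.Sequential(*[nn.Linear(16, 16) for _ in range(8)])
x = torch.randn(2, 16, requires_grad=True)
y = checkpoint_seq(blocks, x, every=2)  # 4 checkpointed segments of 2 blocks each
y.sum().backward()                      # segments are re-run to rebuild activations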
hf_hub_download = None - _has_hf_hub = False - -_logger = logging.getLogger(__name__) - - -def get_cache_dir(child_dir=''): - """ - Returns the location of the directory where models are cached (and creates it if necessary). - """ - # Issue warning to move data if old env is set - if os.getenv('TORCH_MODEL_ZOO'): - _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead') - - hub_dir = get_dir() - child_dir = () if not child_dir else (child_dir,) - model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir) - os.makedirs(model_dir, exist_ok=True) - return model_dir - - -def download_cached_file(url, check_hash=True, progress=False): - if isinstance(url, (list, tuple)): - url, filename = url - else: - parts = urlparse(url) - filename = os.path.basename(parts.path) - cached_file = os.path.join(get_cache_dir(), filename) - if not os.path.exists(cached_file): - _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file)) - hash_prefix = None - if check_hash: - r = HASH_REGEX.search(filename) # r is Optional[Match[str]] - hash_prefix = r.group(1) if r else None - download_url_to_file(url, cached_file, hash_prefix, progress=progress) - return cached_file - - -def has_hf_hub(necessary=False): - if not _has_hf_hub and necessary: - # if no HF Hub module installed, and it is necessary to continue, raise error - raise RuntimeError( - 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') - return _has_hf_hub - - -def hf_split(hf_id): - # FIXME I may change @ -> # and be parsed as fragment in a URI model name scheme - rev_split = hf_id.split('@') - assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.' - hf_model_id = rev_split[0] - hf_revision = rev_split[-1] if len(rev_split) > 1 else None - return hf_model_id, hf_revision - - -def load_cfg_from_json(json_file: Union[str, os.PathLike]): - with open(json_file, "r", encoding="utf-8") as reader: - text = reader.read() - return json.loads(text) - - -def _download_from_hf(model_id: str, filename: str): - hf_model_id, hf_revision = hf_split(model_id) - return hf_hub_download(hf_model_id, filename, revision=hf_revision) - - -def load_model_config_from_hf(model_id: str): - assert has_hf_hub(True) - cached_file = _download_from_hf(model_id, 'config.json') - - hf_config = load_cfg_from_json(cached_file) - if 'pretrained_cfg' not in hf_config: - # old form, pull pretrain_cfg out of the base dict - pretrained_cfg = hf_config - hf_config = {} - hf_config['architecture'] = pretrained_cfg.pop('architecture') - hf_config['num_features'] = pretrained_cfg.pop('num_features', None) - if 'labels' in pretrained_cfg: - hf_config['label_name'] = pretrained_cfg.pop('labels') - hf_config['pretrained_cfg'] = pretrained_cfg - - # NOTE currently discarding parent config as only arch name and pretrained_cfg used in timm right now - pretrained_cfg = hf_config['pretrained_cfg'] - pretrained_cfg['hf_hub_id'] = model_id # insert hf_hub id for pretrained weight load during model creation - pretrained_cfg['source'] = 'hf-hub' - if 'num_classes' in hf_config: - # model should be created with parent num_classes if they exist - pretrained_cfg['num_classes'] = hf_config['num_classes'] - model_name = hf_config['architecture'] - - return pretrained_cfg, model_name - - -def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'): - assert has_hf_hub(True) - cached_file = _download_from_hf(model_id, filename) - state_dict = torch.load(cached_file, 
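# Illustrative sketch, not from the patch: the 'hf-hub' model-id convention these helpers
# handle. The repo id below is hypothetical; hf_split peels an optional '@revision' suffix,
# and create_model routes 'hf-hub:' names through load_model_config_from_hf /
# load_state_dict_from_hf (so the repo must contain a timm-style config.json + weights).
import timm
from timm.models._hub import hf_split  # the hub helpers move to timm/models/_hub.py in this patch

print(hf_split('some-org/some-model@main'))  # -> ('some-org/some-model', 'main')
# model = timm.create_model('hf-hub:some-org/some-model', pretrained=True)  # needs a real repo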
map_location='cpu') - return state_dict - - -def save_for_hf(model, save_directory, model_config=None): - assert has_hf_hub(True) - model_config = model_config or {} - save_directory = Path(save_directory) - save_directory.mkdir(exist_ok=True, parents=True) - - weights_path = save_directory / 'pytorch_model.bin' - torch.save(model.state_dict(), weights_path) - - config_path = save_directory / 'config.json' - hf_config = {} - pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True) - # set some values at root config level - hf_config['architecture'] = pretrained_cfg.pop('architecture') - hf_config['num_classes'] = model_config.get('num_classes', model.num_classes) - hf_config['num_features'] = model_config.get('num_features', model.num_features) - hf_config['global_pool'] = model_config.get('global_pool', getattr(model, 'global_pool', None)) - - if 'label' in model_config: - _logger.warning( - "'label' as a config field for timm models is deprecated. Please use 'label_name' and 'display_name'. " - "Using provided 'label' field as 'label_name'.") - model_config['label_name'] = model_config.pop('label') - - label_name = model_config.pop('label_name', None) - if label_name: - assert isinstance(label_name, (dict, list, tuple)) - # map label id (classifier index) -> unique label name (ie synset for ImageNet, MID for OpenImages) - # can be a dict id: name if there are id gaps, or tuple/list if no gaps. - hf_config['label_name'] = model_config['label_name'] - - display_name = model_config.pop('display_name', None) - if display_name: - assert isinstance(display_name, dict) - # map label_name -> user interface display name - hf_config['display_name'] = model_config['display_name'] - - hf_config['pretrained_cfg'] = pretrained_cfg - hf_config.update(model_config) - - with config_path.open('w') as f: - json.dump(hf_config, f, indent=2) - - -def push_to_hf_hub( - model, - repo_id: str, - commit_message: str = 'Add model', - token: Optional[str] = None, - revision: Optional[str] = None, - private: bool = False, - create_pr: bool = False, - model_config: Optional[dict] = None, -): - # Create repo if it doesn't exist yet - repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True) - - # Infer complete repo_id from repo_url - # Can be different from the input `repo_id` if repo_owner was implicit - _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url) - repo_id = f"{repo_owner}/{repo_name}" - - # Check if README file already exist in repo - try: - get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision)) - has_readme = True - except EntryNotFoundError: - has_readme = False - - # Dump model and push to Hub - with TemporaryDirectory() as tmpdir: - # Save model weights and config. 
- save_for_hf(model, tmpdir, model_config=model_config) - - # Add readme if it does not exist - if not has_readme: - model_name = repo_id.split('/')[-1] - readme_path = Path(tmpdir) / "README.md" - readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {model_name}' - readme_path.write_text(readme_text) - - # Upload model and return - return upload_folder( - repo_id=repo_id, - folder_path=tmpdir, - revision=revision, - create_pr=create_pr, - commit_message=commit_message, - ) +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning) diff --git a/timm/models/inception_resnet_v2.py b/timm/models/inception_resnet_v2.py index fa7b8ec8..3006f3d2 100644 --- a/timm/models/inception_resnet_v2.py +++ b/timm/models/inception_resnet_v2.py @@ -7,9 +7,10 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg, flatten_modules -from .layers import create_classifier -from .registry import register_model +from timm.layers import create_classifier +from ._builder import build_model_with_cfg +from ._manipulate import flatten_modules +from ._registry import register_model __all__ = ['InceptionResnetV2'] diff --git a/timm/models/inception_v3.py b/timm/models/inception_v3.py index c70bd608..28794ce6 100644 --- a/timm/models/inception_v3.py +++ b/timm/models/inception_v3.py @@ -8,9 +8,13 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg, resolve_pretrained_cfg, flatten_modules -from .registry import register_model -from .layers import trunc_normal_, create_classifier, Linear +from timm.layers import trunc_normal_, create_classifier, Linear +from ._builder import build_model_with_cfg +from ._builder import resolve_pretrained_cfg +from ._manipulate import flatten_modules +from ._registry import register_model + +__all__ = ['InceptionV3', 'InceptionV3Aux'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/inception_v4.py b/timm/models/inception_v4.py index 5f4e208f..c1559829 100644 --- a/timm/models/inception_v4.py +++ b/timm/models/inception_v4.py @@ -7,9 +7,9 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg -from .layers import create_classifier -from .registry import register_model +from timm.layers import create_classifier +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['InceptionV4'] diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 21c641b6..97e70563 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -1,44 +1,48 @@ -from .activations import * -from .adaptive_avgmax_pool import \ +# NOTE timm.models.layers is DEPRECATED, please use timm.layers, this is here to reduce breakages in transition +from timm.layers.activations import * +from timm.layers.adaptive_avgmax_pool import \ adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d -from .blur_pool import BlurPool2d -from .classifier import ClassifierHead, create_classifier -from .cond_conv2d import CondConv2d, get_condconv_initializer -from .config import 
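# Illustrative sketch: publishing a model with the push_to_hf_hub helper shown above
# (hypothetical repo id; requires `huggingface_hub` and a logged-in token, per has_hf_hub()).
import timm
from timm.models._hub import push_to_hf_hub

model = timm.create_model('resnet50d', pretrained=True, num_classes=2)
# ... fine-tune ...
# Saves pytorch_model.bin + config.json via save_for_hf() into a temp dir, then uploads it:
push_to_hf_hub(model, 'your-user/resnet50d-finetuned', commit_message='Add model')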
is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ +from timm.layers.blur_pool import BlurPool2d +from timm.layers.classifier import ClassifierHead, create_classifier +from timm.layers.cond_conv2d import CondConv2d, get_condconv_initializer +from timm.layers.config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ set_layer_config -from .conv2d_same import Conv2dSame, conv2d_same -from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct -from .create_act import create_act_layer, get_act_layer, get_act_fn -from .create_attn import get_attn, create_attn -from .create_conv2d import create_conv2d -from .create_norm import get_norm_layer, create_norm_layer -from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer -from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path -from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn -from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\ +from timm.layers.conv2d_same import Conv2dSame, conv2d_same +from timm.layers.conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct +from timm.layers.create_act import create_act_layer, get_act_layer, get_act_fn +from timm.layers.create_attn import get_attn, create_attn +from timm.layers.create_conv2d import create_conv2d +from timm.layers.create_norm import get_norm_layer, create_norm_layer +from timm.layers.create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer +from timm.layers.drop import DropBlock2d, DropPath, drop_block_2d, drop_path +from timm.layers.eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn +from timm.layers.evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\ EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a -from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm -from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d -from .gather_excite import GatherExcite -from .global_context import GlobalContext -from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple -from .inplace_abn import InplaceAbn -from .linear import Linear -from .mixed_conv2d import MixedConv2d -from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp -from .non_local_attn import NonLocalAttn, BatNonLocalAttn -from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d -from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm -from .padding import get_padding, get_same_padding, pad_same -from .patch_embed import PatchEmbed -from .pool2d_same import AvgPool2dSame, create_pool2d -from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite -from .selective_kernel import SelectiveKernel -from .separable_conv import SeparableConv2d, SeparableConvNormAct -from .space_to_depth import SpaceToDepthModule -from .split_attn import SplitAttn -from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model -from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame -from .test_time_pool import TestTimePoolHead, apply_test_time_pool -from .trace_utils import _assert, _float_to_int -from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_ +from timm.layers.fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm +from timm.layers.filter_response_norm import 
FilterResponseNormTlu2d, FilterResponseNormAct2d +from timm.layers.gather_excite import GatherExcite +from timm.layers.global_context import GlobalContext +from timm.layers.helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple +from timm.layers.inplace_abn import InplaceAbn +from timm.layers.linear import Linear +from timm.layers.mixed_conv2d import MixedConv2d +from timm.layers.mlp import Mlp, GluMlp, GatedMlp, ConvMlp +from timm.layers.non_local_attn import NonLocalAttn, BatNonLocalAttn +from timm.layers.norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d +from timm.layers.norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm +from timm.layers.padding import get_padding, get_same_padding, pad_same +from timm.layers.patch_embed import PatchEmbed +from timm.layers.pool2d_same import AvgPool2dSame, create_pool2d +from timm.layers.squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite +from timm.layers.selective_kernel import SelectiveKernel +from timm.layers.separable_conv import SeparableConv2d, SeparableConvNormAct +from timm.layers.space_to_depth import SpaceToDepthModule +from timm.layers.split_attn import SplitAttn +from timm.layers.split_batchnorm import SplitBatchNorm2d, convert_splitbn_model +from timm.layers.std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame +from timm.layers.test_time_pool import TestTimePoolHead, apply_test_time_pool +from timm.layers.trace_utils import _assert, _float_to_int +from timm.layers.weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_ + +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning) diff --git a/timm/models/levit.py b/timm/models/levit.py index cea9f0fc..8dc11309 100644 --- a/timm/models/levit.py +++ b/timm/models/levit.py @@ -23,8 +23,6 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W # Modified from # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # Copyright 2020 Ross Wightman, Apache-2.0 License -import itertools -from copy import deepcopy from functools import partial from typing import Dict @@ -32,10 +30,12 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import to_ntuple, get_act_layer -from .vision_transformer import trunc_normal_ -from .registry import register_model +from timm.layers import to_ntuple, get_act_layer, trunc_normal_ +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model + +__all__ = ['LevitDistilled'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index 3f315093..1e2666e5 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -45,17 +45,17 @@ from typing import Callable, Optional, Union, Tuple, List import torch from torch import nn -from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, checkpoint_seq, named_apply -from .fx_features import register_notrace_function -from .layers import Mlp, ConvMlp, DropPath, ClassifierHead, trunc_normal_tf_, LayerNorm2d, LayerNorm -from .layers import create_attn, get_act_layer, 
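# Illustrative sketch of the import transition these shims implement: timm.models.layers keeps
# re-exporting everything (and emits a deprecation warning on import), while new code is
# expected to import from timm.layers directly.
from timm.layers import DropPath, trunc_normal_, ClassifierHead  # preferred going forward
from timm.models.layers import DropPath as DropPathCompat        # still works, warns on import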
get_norm_layer, get_norm_act_layer, create_conv2d -from .layers import SelectAdaptivePool2d, create_pool2d -from .layers import to_2tuple, extend_tuple, make_divisible, _assert -from .pretrained import generate_default_cfgs -from .registry import register_model +from timm.layers import Mlp, ConvMlp, DropPath, ClassifierHead, trunc_normal_tf_, LayerNorm +from timm.layers import SelectAdaptivePool2d, create_pool2d +from timm.layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d +from timm.layers import to_2tuple, extend_tuple, make_divisible, _assert +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._manipulate import named_apply, checkpoint_seq +from ._pretrained import generate_default_cfgs +from ._registry import register_model from .vision_transformer_relpos import RelPosMlp, RelPosBias # FIXME move these to common location __all__ = ['MaxxVitCfg', 'MaxxVitConvCfg', 'MaxxVitTransformerCfg', 'MaxxVit'] diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py index a77e2eb7..a7825899 100644 --- a/timm/models/mlp_mixer.py +++ b/timm/models/mlp_mixer.py @@ -39,16 +39,18 @@ A thank you to paper authors for releasing code and weights. Hacked together by / Copyright 2021 Ross Wightman """ import math -from copy import deepcopy from functools import partial import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, named_apply, checkpoint_seq -from .layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple -from .registry import register_model +from timm.layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple +from ._builder import build_model_with_cfg +from ._manipulate import named_apply, checkpoint_seq +from ._registry import register_model + +__all__ = ['MixerBlock'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index bb72ccb8..cf4f268d 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -14,13 +14,14 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .efficientnet_blocks import SqueezeExcite -from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights,\ +from timm.layers import SelectAdaptivePool2d, Linear, create_conv2d, get_norm_act_layer +from ._builder import build_model_with_cfg, pretrained_cfg_for_features +from ._efficientnet_blocks import SqueezeExcite +from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \ round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT -from .features import FeatureInfo, FeatureHooks -from .helpers import build_model_with_cfg, pretrained_cfg_for_features, checkpoint_seq -from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, get_norm_act_layer -from .registry import register_model +from ._features import FeatureInfo, FeatureHooks +from ._manipulate import checkpoint_seq +from ._registry import register_model __all__ = ['MobileNetV3', 'MobileNetV3Features'] diff --git a/timm/models/mobilevit.py b/timm/models/mobilevit.py index bd5479a7..3d2ae84a 100644 --- a/timm/models/mobilevit.py +++ b/timm/models/mobilevit.py @@ -14,18 +14,18 @@ Rest of code, 
ByobNet, and Transformer block hacked together by / Copyright 2022 # Copyright (C) 2020 Apple Inc. All Rights Reserved. # import math -from typing import Union, Callable, Dict, Tuple, Optional, Sequence +from typing import Callable, Tuple, Optional import torch -from torch import nn import torch.nn.functional as F +from torch import nn +from timm.layers import to_2tuple, make_divisible, GroupNorm1, ConvMlp, DropPath +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_module +from ._registry import register_model from .byobnet import register_block, ByoBlockCfg, ByoModelCfg, ByobNet, LayerFn, num_groups -from .fx_features import register_notrace_module -from .layers import to_2tuple, make_divisible, LayerNorm2d, GroupNorm1, ConvMlp, DropPath from .vision_transformer import Block as TransformerBlock -from .helpers import build_model_with_cfg -from .registry import register_model __all__ = [] diff --git a/timm/models/mvitv2.py b/timm/models/mvitv2.py index c5aaa09e..5c0a6650 100644 --- a/timm/models/mvitv2.py +++ b/timm/models/mvitv2.py @@ -24,10 +24,12 @@ import torch.utils.checkpoint as checkpoint from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_function -from .helpers import build_model_with_cfg -from .layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple -from .registry import register_model +from timm.layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._registry import register_model + +__all__ = ['MultiScaleVit', 'MultiScaleVitCfg'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py index 50db1a3d..0b2178d6 100644 --- a/timm/models/nasnet.py +++ b/timm/models/nasnet.py @@ -8,9 +8,9 @@ import torch import torch.nn as nn import torch.nn.functional as F -from .helpers import build_model_with_cfg -from .layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier -from .registry import register_model +from timm.layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['NASNetALarge'] diff --git a/timm/models/nest.py b/timm/models/nest.py index 8692a2b1..c9c6258c 100644 --- a/timm/models/nest.py +++ b/timm/models/nest.py @@ -25,12 +25,14 @@ import torch.nn.functional as F from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_function -from .helpers import build_model_with_cfg, named_apply, checkpoint_seq -from .layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_ -from .layers import _assert -from .layers import create_conv2d, create_pool2d, to_ntuple -from .registry import register_model +from timm.layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_, _assert +from timm.layers import create_conv2d, create_pool2d, to_ntuple +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._manipulate import checkpoint_seq, named_apply +from ._registry import register_model + +__all__ = ['Nest'] # model_registry will add each entrypoint fn to this _logger = logging.getLogger(__name__) diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index 3a45410b..48f91b35 100644 
--- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -16,21 +16,23 @@ Status: Hacked together by / copyright Ross Wightman, 2021. """ -import math -from dataclasses import dataclass, field from collections import OrderedDict -from typing import Tuple, Optional +from dataclasses import dataclass from functools import partial +from typing import Tuple, Optional import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_module -from .helpers import build_model_with_cfg, checkpoint_seq -from .registry import register_model -from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\ +from timm.layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame, \ get_act_layer, get_act_fn, get_attn, make_divisible +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_module +from ._manipulate import checkpoint_seq +from ._registry import register_model + +__all__ = ['NormFreeNet', 'NfCfg'] # model_registry will add each entrypoint fn to this def _dcfg(url='', **kwargs): diff --git a/timm/models/pit.py b/timm/models/pit.py index 0f571319..4f40e5e0 100644 --- a/timm/models/pit.py +++ b/timm/models/pit.py @@ -13,7 +13,6 @@ Modifications for timm by / Copyright 2020 Ross Wightman import math import re -from copy import deepcopy from functools import partial from typing import Tuple @@ -21,12 +20,15 @@ import torch from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import trunc_normal_, to_2tuple -from .registry import register_model +from timm.layers import trunc_normal_, to_2tuple +from ._builder import build_model_with_cfg +from ._registry import register_model from .vision_transformer import Block +__all__ = ['PoolingVisionTransformer'] # model_registry will add each entrypoint fn to this + + def _cfg(url='', **kwargs): return { 'url': url, diff --git a/timm/models/pnasnet.py b/timm/models/pnasnet.py index 81067845..7291c8fb 100644 --- a/timm/models/pnasnet.py +++ b/timm/models/pnasnet.py @@ -12,9 +12,9 @@ import torch import torch.nn as nn import torch.nn.functional as F -from .helpers import build_model_with_cfg -from .layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier -from .registry import register_model +from timm.layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['PNASNet5Large'] diff --git a/timm/models/poolformer.py b/timm/models/poolformer.py index 09359bc8..b4d2d18f 100644 --- a/timm/models/poolformer.py +++ b/timm/models/poolformer.py @@ -19,15 +19,15 @@ Modifications and additions for timm by / Copyright 2022, Ross Wightman # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import os -import copy import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import DropPath, trunc_normal_, to_2tuple, ConvMlp, GroupNorm1 -from .registry import register_model +from timm.layers import DropPath, trunc_normal_, to_2tuple, ConvMlp, GroupNorm1 +from ._builder import build_model_with_cfg +from ._registry import register_model + +__all__ = ['PoolFormer'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/pvt_v2.py b/timm/models/pvt_v2.py index dd3cf690..696a2506 100644 --- a/timm/models/pvt_v2.py +++ b/timm/models/pvt_v2.py @@ -24,9 +24,9 @@ import torch.nn as nn import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import DropPath, to_2tuple, to_ntuple, trunc_normal_ -from .registry import register_model +from timm.layers import DropPath, to_2tuple, to_ntuple, trunc_normal_ +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['PyramidVisionTransformerV2'] diff --git a/timm/models/registry.py b/timm/models/registry.py index 159ffb5f..58e2e1f4 100644 --- a/timm/models/registry.py +++ b/timm/models/registry.py @@ -1,210 +1,4 @@ -""" Model Registry -Hacked together by / Copyright 2020 Ross Wightman -""" +from ._registry import * -import fnmatch -import re -import sys -from collections import defaultdict, deque -from copy import deepcopy -from typing import List, Optional, Union, Tuple - -from .pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag - -__all__ = [ - 'list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules', - 'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name'] - -_module_to_models = defaultdict(set) # dict of sets to check membership of model in module -_model_to_module = {} # mapping of model names to module names -_model_entrypoints = {} # mapping of model names to architecture entrypoint fns -_model_has_pretrained = set() # set of model names that have pretrained weight url present -_model_default_cfgs = dict() # central repo for model arch -> default cfg objects -_model_pretrained_cfgs = dict() # central repo for model arch + tag -> pretrained cfgs -_model_with_tags = defaultdict(list) # shortcut to map each model arch to all model + tag names - - -def get_arch_name(model_name: str) -> Tuple[str, Optional[str]]: - return split_model_name_tag(model_name)[0] - - -def register_model(fn): - # lookup containing module - mod = sys.modules[fn.__module__] - module_name_split = fn.__module__.split('.') - module_name = module_name_split[-1] if len(module_name_split) else '' - - # add model to __all__ in module - model_name = fn.__name__ - if hasattr(mod, '__all__'): - mod.__all__.append(model_name) - else: - mod.__all__ = [model_name] - - # add entries to registry dict/sets - _model_entrypoints[model_name] = fn - _model_to_module[model_name] = module_name - _module_to_models[module_name].add(model_name) - if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs: - # this will catch all models that have entrypoint matching cfg key, but miss any aliasing - # entrypoints or non-matching combos - cfg = mod.default_cfgs[model_name] - if not isinstance(cfg, DefaultCfg): - # new style default cfg dataclass w/ multiple entries per model-arch - assert isinstance(cfg, dict) - # old style cfg 
dict per model-arch - cfg = PretrainedCfg(**cfg) - cfg = DefaultCfg(tags=deque(['']), cfgs={'': cfg}) - - for tag_idx, tag in enumerate(cfg.tags): - is_default = tag_idx == 0 - pretrained_cfg = cfg.cfgs[tag] - if is_default: - _model_pretrained_cfgs[model_name] = pretrained_cfg - if pretrained_cfg.has_weights: - # add tagless entry if it's default and has weights - _model_has_pretrained.add(model_name) - if tag: - model_name_tag = '.'.join([model_name, tag]) - _model_pretrained_cfgs[model_name_tag] = pretrained_cfg - if pretrained_cfg.has_weights: - # add model w/ tag if tag is valid - _model_has_pretrained.add(model_name_tag) - _model_with_tags[model_name].append(model_name_tag) - else: - _model_with_tags[model_name].append(model_name) # has empty tag (to slowly remove these instances) - - _model_default_cfgs[model_name] = cfg - - return fn - - -def _natural_key(string_): - return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] - - -def list_models( - filter: Union[str, List[str]] = '', - module: str = '', - pretrained=False, - exclude_filters: str = '', - name_matches_cfg: bool = False, - include_tags: Optional[bool] = None, -): - """ Return list of available model names, sorted alphabetically - - Args: - filter (str) - Wildcard filter string that works with fnmatch - module (str) - Limit model selection to a specific submodule (ie 'vision_transformer') - pretrained (bool) - Include only models with valid pretrained weights if True - exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter - name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases) - include_tags (Optional[boo]) - Include pretrained tags in model names (model.tag). If None, defaults - set to True when pretrained=True else False (default: None) - Example: - model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet' - model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module - """ - if include_tags is None: - # FIXME should this be default behaviour? or default to include_tags=True? 
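# Illustrative sketch (hypothetical toy entrypoint) of what register_model above does: it
# records the entrypoint, appends its name to the defining module's __all__ (hence the
# "__all__ = [...]  # model_registry will add each entrypoint fn to this" comments added
# throughout this patch), and indexes any matching default_cfgs entry as a pretrained config.
import torch.nn as nn
from timm.models._registry import register_model, is_model, model_entrypoint

default_cfgs = {'toynet': {'url': ''}}  # keyed by entrypoint name; no weights in this sketch

@register_model
def toynet(pretrained=False, **kwargs):
    return nn.Linear(8, 8)

assert is_model('toynet') and model_entrypoint('toynet') is toynet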
- include_tags = pretrained - - if module: - all_models = list(_module_to_models[module]) - else: - all_models = _model_entrypoints.keys() - - if include_tags: - # expand model names to include names w/ pretrained tags - models_with_tags = [] - for m in all_models: - models_with_tags.extend(_model_with_tags[m]) - all_models = models_with_tags - - if filter: - models = [] - include_filters = filter if isinstance(filter, (tuple, list)) else [filter] - for f in include_filters: - include_models = fnmatch.filter(all_models, f) # include these models - if len(include_models): - models = set(models).union(include_models) - else: - models = all_models - - if exclude_filters: - if not isinstance(exclude_filters, (tuple, list)): - exclude_filters = [exclude_filters] - for xf in exclude_filters: - exclude_models = fnmatch.filter(models, xf) # exclude these models - if len(exclude_models): - models = set(models).difference(exclude_models) - - if pretrained: - models = _model_has_pretrained.intersection(models) - - if name_matches_cfg: - models = set(_model_pretrained_cfgs).intersection(models) - - return list(sorted(models, key=_natural_key)) - - -def list_pretrained( - filter: Union[str, List[str]] = '', - exclude_filters: str = '', -): - return list_models( - filter=filter, - pretrained=True, - exclude_filters=exclude_filters, - include_tags=True, - ) - - -def is_model(model_name): - """ Check if a model name exists - """ - arch_name = get_arch_name(model_name) - return arch_name in _model_entrypoints - - -def model_entrypoint(model_name): - """Fetch a model entrypoint for specified model name - """ - arch_name = get_arch_name(model_name) - return _model_entrypoints[arch_name] - - -def list_modules(): - """ Return list of module names that contain models / model entrypoints - """ - modules = _module_to_models.keys() - return list(sorted(modules)) - - -def is_model_in_modules(model_name, module_names): - """Check if a model exists within a subset of modules - Args: - model_name (str) - name of model to check - module_names (tuple, list, set) - names of modules to search in - """ - arch_name = get_arch_name(model_name) - assert isinstance(module_names, (tuple, list, set)) - return any(arch_name in _module_to_models[n] for n in module_names) - - -def is_model_pretrained(model_name): - return model_name in _model_has_pretrained - - -def get_pretrained_cfg(model_name): - if model_name in _model_pretrained_cfgs: - return deepcopy(_model_pretrained_cfgs[model_name]) - raise RuntimeError(f'No pretrained config exists for model {model_name}.') - - -def get_pretrained_cfg_value(model_name, cfg_key): - """ Get a specific model default_cfg value by key. None if key doesn't exist. 
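# Illustrative sketch: querying the registry maintained by the functions above. With
# pretrained=True, the pretrained tags introduced in this release appear as 'arch.tag' names.
import timm
timm.list_models('eva_*')                   # architecture names matching a wildcard
timm.list_models('eva_*', pretrained=True)  # e.g. includes 'eva_large_patch14_196.in22k_ft_in1k'
timm.is_model('eva_large_patch14_336')      # True once vision_transformer registers it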
- """ - if model_name in _model_pretrained_cfgs: - return getattr(_model_pretrained_cfgs[model_name], cfg_key, None) - raise RuntimeError(f'No pretrained config exist for model {model_name}.') \ No newline at end of file +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning) diff --git a/timm/models/regnet.py b/timm/models/regnet.py index 0ad7c826..e1cc821b 100644 --- a/timm/models/regnet.py +++ b/timm/models/regnet.py @@ -23,10 +23,13 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, named_apply, checkpoint_seq -from .layers import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, GroupNormAct -from .layers import get_act_layer, get_norm_act_layer, create_conv2d -from .registry import register_model +from timm.layers import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, GroupNormAct +from timm.layers import get_act_layer, get_norm_act_layer, create_conv2d +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq, named_apply +from ._registry import register_model + +__all__ = ['RegNet', 'RegNetCfg'] # model_registry will add each entrypoint fn to this @dataclass diff --git a/timm/models/res2net.py b/timm/models/res2net.py index 6c2fd1bf..4724df2a 100644 --- a/timm/models/res2net.py +++ b/timm/models/res2net.py @@ -8,8 +8,8 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .registry import register_model +from ._builder import build_model_with_cfg +from ._registry import register_model from .resnet import ResNet __all__ = [] diff --git a/timm/models/resnest.py b/timm/models/resnest.py index 735b91a2..3b001c7b 100644 --- a/timm/models/resnest.py +++ b/timm/models/resnest.py @@ -6,13 +6,12 @@ Adapted from original PyTorch impl w/ weights at https://github.com/zhanghang198 Modified for torchscript compat, and consistency with timm by Ross Wightman """ -import torch from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import SplitAttn -from .registry import register_model +from timm.layers import SplitAttn +from ._builder import build_model_with_cfg +from ._registry import register_model from .resnet import ResNet diff --git a/timm/models/resnet.py b/timm/models/resnet.py index d0d98894..50849017 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -15,9 +15,11 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, create_attn, get_attn, create_classifier -from .registry import register_model +from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, create_attn, get_attn, \ + create_classifier +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model, model_entrypoint __all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this @@ -675,6 +677,11 @@ class ResNet(nn.Module): self.init_weights(zero_init_last=zero_init_last) + @staticmethod + def from_pretrained(model_name: str, load_weights=True, **kwargs) -> 'ResNet': + entry_fn = 
model_entrypoint(model_name, 'resnet') + return entry_fn(pretrained=not load_weights, **kwargs) + @torch.jit.ignore def init_weights(self, zero_init_last=True): for n, m in self.named_modules(): @@ -822,7 +829,7 @@ def resnet50(pretrained=False, **kwargs): @register_model -def resnet50d(pretrained=False, **kwargs): +def resnet50d(pretrained=False, **kwargs) -> ResNet: """Constructs a ResNet-50-D model. """ model_args = dict( diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index b21ef7f5..f8c4298b 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -30,16 +30,19 @@ Original copyright of Google code below, modifications by Ross Wightman, Copyrig # limitations under the License. from collections import OrderedDict # pylint: disable=g-importing-member +from functools import partial import torch import torch.nn as nn -from functools import partial from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg, named_apply, adapt_input_conv, checkpoint_seq -from .registry import register_model -from .layers import GroupNormAct, BatchNormAct2d, EvoNorm2dB0, EvoNorm2dS0, EvoNorm2dS1, FilterResponseNormTlu2d,\ +from timm.layers import GroupNormAct, BatchNormAct2d, EvoNorm2dB0, EvoNorm2dS0, FilterResponseNormTlu2d, \ ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq, named_apply, adapt_input_conv +from ._registry import register_model + +__all__ = ['ResNetV2'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py index 33e97222..51e8cdc2 100644 --- a/timm/models/rexnet.py +++ b/timm/models/rexnet.py @@ -10,16 +10,20 @@ Changes for timm, feature extraction, and rounded channel variant hacked togethe Copyright 2020 Ross Wightman """ -import torch -import torch.nn as nn from functools import partial from math import ceil +import torch +import torch.nn as nn + from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import ClassifierHead, create_act_layer, ConvNormAct, DropPath, make_divisible, SEModule -from .registry import register_model -from .efficientnet_builder import efficientnet_init_weights +from timm.layers import ClassifierHead, create_act_layer, ConvNormAct, DropPath, make_divisible, SEModule +from ._builder import build_model_with_cfg +from ._efficientnet_builder import efficientnet_init_weights +from ._manipulate import checkpoint_seq +from ._registry import register_model + +__all__ = ['ReXNetV1'] # model_registry will add each entrypoint fn to this def _cfg(url=''): diff --git a/timm/models/selecsls.py b/timm/models/selecsls.py index 1a9ac929..4d40c49a 100644 --- a/timm/models/selecsls.py +++ b/timm/models/selecsls.py @@ -16,9 +16,9 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import create_classifier -from .registry import register_model +from timm.layers import create_classifier +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['SelecSLS'] # model_registry will add each entrypoint fn to this diff --git a/timm/models/senet.py b/timm/models/senet.py index a9e23ff1..d36e9854 100644 --- a/timm/models/senet.py +++ b/timm/models/senet.py @@ -19,9 +19,9 @@ 
import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import create_classifier -from .registry import register_model +from timm.layers import create_classifier +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['SENet'] diff --git a/timm/models/sequencer.py b/timm/models/sequencer.py index b1ae92a4..f3f758b9 100644 --- a/timm/models/sequencer.py +++ b/timm/models/sequencer.py @@ -6,7 +6,6 @@ Paper: `Sequencer: Deep LSTM for Image Classification` - https://arxiv.org/pdf/2 # Copyright (c) 2022. Yuki Tatsunami # Licensed under the Apache License, Version 2.0 (the "License"); - import math from functools import partial from typing import Tuple @@ -15,9 +14,12 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT -from .helpers import build_model_with_cfg, named_apply -from .layers import lecun_normal_, DropPath, Mlp, PatchEmbed as TimmPatchEmbed -from .registry import register_model +from timm.layers import lecun_normal_, DropPath, Mlp, PatchEmbed as TimmPatchEmbed +from ._builder import build_model_with_cfg +from ._manipulate import named_apply +from ._registry import register_model + +__all__ = ['Sequencer2D'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/sknet.py b/timm/models/sknet.py index fb9f063a..5a29b9a4 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -13,9 +13,9 @@ import math from torch import nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import SelectiveKernel, ConvNormAct, ConvNormActAa, create_attn -from .registry import register_model +from timm.layers import SelectiveKernel, ConvNormAct, create_attn +from ._builder import build_model_with_cfg +from ._registry import register_model from .resnet import ResNet diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index f2305fb2..5df06d4d 100644 --- a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -17,19 +17,20 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W # -------------------------------------------------------- import logging import math -from functools import partial from typing import Optional import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_function -from .helpers import build_model_with_cfg, named_apply, checkpoint_seq -from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert -from .registry import register_model +from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._manipulate import checkpoint_seq, named_apply +from ._registry import register_model from .vision_transformer import checkpoint_filter_fn, get_init_weights_vit +__all__ = ['SwinTransformer'] # model_registry will add each entrypoint fn to this _logger = logging.getLogger(__name__) diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py index 0c9db3dd..efaaa9e9 100644 --- a/timm/models/swin_transformer_v2.py +++ b/timm/models/swin_transformer_v2.py @@ -21,10 +21,12 @@ import 
torch.nn.functional as F import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_function -from .helpers import build_model_with_cfg, named_apply -from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert -from .registry import register_model +from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._registry import register_model + +__all__ = ['SwinTransformerV2'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index d143c14c..cf10b39c 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -29,7 +29,6 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W # -------------------------------------------------------- import logging import math -from copy import deepcopy from typing import Tuple, Optional, List, Union, Any, Type import torch @@ -38,11 +37,13 @@ import torch.nn.functional as F import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_function -from .helpers import build_model_with_cfg, named_apply -from .layers import DropPath, Mlp, to_2tuple, _assert -from .registry import register_model +from timm.layers import DropPath, Mlp, to_2tuple, _assert +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._manipulate import named_apply +from ._registry import register_model +__all__ = ['SwinTransformerV2Cr'] # model_registry will add each entrypoint fn to this _logger = logging.getLogger(__name__) diff --git a/timm/models/tnt.py b/timm/models/tnt.py index 5b72b196..50088baf 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -7,17 +7,18 @@ The official mindspore code is released and available at https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT """ import math + import torch import torch.nn as nn from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.models.helpers import build_model_with_cfg -from timm.models.layers import Mlp, DropPath, trunc_normal_ -from timm.models.layers.helpers import to_2tuple -from timm.models.layers import _assert -from timm.models.registry import register_model -from timm.models.vision_transformer import resize_pos_embed +from timm.layers import Mlp, DropPath, trunc_normal_, _assert, to_2tuple +from ._builder import build_model_with_cfg +from ._registry import register_model +from .vision_transformer import resize_pos_embed + +__all__ = ['TNT'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py index 2469acd2..83cb0576 100644 --- a/timm/models/tresnet.py +++ b/timm/models/tresnet.py @@ -10,11 +10,11 @@ from collections import OrderedDict import torch import torch.nn as nn -from .helpers import build_model_with_cfg -from .layers import SpaceToDepthModule, BlurPool2d, InplaceAbn, ClassifierHead, SEModule -from .registry import register_model +from timm.layers import SpaceToDepthModule, BlurPool2d, InplaceAbn, ClassifierHead, SEModule +from ._builder import build_model_with_cfg 
+from ._registry import register_model -__all__ = ['tresnet_m', 'tresnet_l', 'tresnet_xl'] +__all__ = ['TResNet'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/twins.py b/timm/models/twins.py index 0626db37..41944c36 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -12,20 +12,21 @@ Code/weights from https://github.com/Meituan-AutoML/Twins, original copyright/li # Written by Xinjie Li, Xiangxiang Chu # -------------------------------------------------------- import math -from copy import deepcopy -from typing import Optional, Tuple +from functools import partial +from typing import Tuple import torch import torch.nn as nn import torch.nn.functional as F -from functools import partial from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .layers import Mlp, DropPath, to_2tuple, trunc_normal_ -from .fx_features import register_notrace_module -from .registry import register_model +from timm.layers import Mlp, DropPath, to_2tuple, trunc_normal_ +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_module +from ._registry import register_model from .vision_transformer import Attention -from .helpers import build_model_with_cfg + +__all__ = ['Twins'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/vgg.py b/timm/models/vgg.py index caf96517..abe9f8d5 100644 --- a/timm/models/vgg.py +++ b/timm/models/vgg.py @@ -5,21 +5,19 @@ timm functionality. Copyright 2021 Ross Wightman """ +from typing import Union, List, Dict, Any, cast + import torch import torch.nn as nn import torch.nn.functional as F -from typing import Union, List, Dict, Any, cast from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, checkpoint_seq -from .fx_features import register_notrace_module -from .layers import ClassifierHead -from .registry import register_model - -__all__ = [ - 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', - 'vgg19_bn', 'vgg19', -] +from timm.layers import ClassifierHead +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_module +from ._registry import register_model + +__all__ = ['VGG'] def _cfg(url='', **kwargs): diff --git a/timm/models/visformer.py b/timm/models/visformer.py index 254a0748..e15ae4a5 100644 --- a/timm/models/visformer.py +++ b/timm/models/visformer.py @@ -6,17 +6,15 @@ From original at https://github.com/danczs/Visformer Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman """ -from copy import deepcopy import torch import torch.nn as nn -import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d, create_classifier -from .registry import register_model - +from timm.layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d, create_classifier +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model __all__ = ['Visformer'] diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 4effbed6..5b93628f 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -19,10 +19,10 @@ for some einops/einsum fun Hacked together by / Copyright 2020, Ross Wightman 
""" -import math import logging -from functools import partial +import math from collections import OrderedDict +from functools import partial from typing import Optional import torch @@ -30,12 +30,17 @@ import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD,\ +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \ OPENAI_CLIP_MEAN, OPENAI_CLIP_STD -from .helpers import build_model_with_cfg, named_apply, adapt_input_conv, checkpoint_seq -from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ -from .pretrained import generate_default_cfgs -from .registry import register_model +from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ +from ._builder import build_model_with_cfg +from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv +from ._pretrained import generate_default_cfgs +from ._registry import register_model + + +__all__ = ['VisionTransformer'] # model_registry will add each entrypoint fn to this + _logger = logging.getLogger(__name__) @@ -933,6 +938,25 @@ default_cfgs = generate_default_cfgs({ 'vit_small_patch16_36x1_224': _cfg(url=''), 'vit_small_patch16_18x2_224': _cfg(url=''), 'vit_base_patch16_18x2_224': _cfg(url=''), + + # EVA fine-tuned weights from MAE style MIM - EVA-CLIP target pretrain + # https://github.com/baaivision/EVA/blob/7ecf2c0a370d97967e86d047d7af9188f78d2df3/eva/README.md#eva-l-learning-better-mim-representations-from-eva-clip + 'eva_large_patch14_196.in22k_ft_in22k_in1k': _cfg( + hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_21k_to_1k_ft_88p6.pt', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 196, 196), crop_pct=1.0), + 'eva_large_patch14_336.in22k_ft_in22k_in1k': _cfg( + hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_21k_to_1k_ft_89p2.pt', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'), + 'eva_large_patch14_196.in22k_ft_in1k': _cfg( + hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_1k_ft_88p0.pt', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 196, 196), crop_pct=1.0), + 'eva_large_patch14_336.in22k_ft_in1k': _cfg( + hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_1k_ft_88p65.pt', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'), }) @@ -1354,3 +1378,21 @@ def vit_base_patch16_18x2_224(pretrained=False, **kwargs): patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock, **kwargs) model = _create_vision_transformer('vit_base_patch16_18x2_224', pretrained=pretrained, **model_kwargs) return model + + +@register_model +def eva_large_patch14_196(pretrained=False, **kwargs): + """ EVA-large model https://arxiv.org/abs/2211.07636 /via MAE MIM pretrain""" + model_kwargs = dict( + patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg', **kwargs) + model = _create_vision_transformer('eva_large_patch14_196', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def eva_large_patch14_336(pretrained=False, **kwargs): + """ EVA-large model https://arxiv.org/abs/2211.07636 via MAE MIM pretrain""" + model_kwargs = dict( + patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg', **kwargs) + model = 
_create_vision_transformer('eva_large_patch14_336', pretrained=pretrained, **model_kwargs) + return model diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index 5e5113d7..cfdd0a0e 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -13,19 +13,18 @@ They were moved here to keep file sizes sane. Hacked together by / Copyright 2020, Ross Wightman """ -from copy import deepcopy from functools import partial import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .layers import StdConv2dSame, StdConv2d, to_2tuple -from .pretrained import generate_default_cfgs +from timm.layers import StdConv2dSame, StdConv2d, to_2tuple +from ._pretrained import generate_default_cfgs +from ._registry import register_model from .resnet import resnet26d, resnet50d from .resnetv2 import ResNetV2, create_resnetv2_stem -from .registry import register_model -from timm.models.vision_transformer import _create_vision_transformer +from .vision_transformer import _create_vision_transformer def _cfg(url='', **kwargs): diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py index 52b3ce45..1a7c2f40 100644 --- a/timm/models/vision_transformer_relpos.py +++ b/timm/models/vision_transformer_relpos.py @@ -4,11 +4,9 @@ NOTE: these models are experimental / WIP, expect changes Hacked together by / Copyright 2022, Ross Wightman """ -import math import logging +import math from functools import partial -from collections import OrderedDict -from dataclasses import dataclass from typing import Optional, Tuple import torch @@ -16,10 +14,12 @@ import torch.nn as nn import torch.nn.functional as F from torch.utils.checkpoint import checkpoint -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply -from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, to_2tuple -from .registry import register_model +from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_ +from ._builder import build_model_with_cfg +from ._registry import register_model + +__all__ = ['VisionTransformerRelPos'] # model_registry will add each entrypoint fn to this _logger = logging.getLogger(__name__) diff --git a/timm/models/volo.py b/timm/models/volo.py index 735453c8..1117995a 100644 --- a/timm/models/volo.py +++ b/timm/models/volo.py @@ -20,17 +20,19 @@ Modifications and additions for timm by / Copyright 2022, Ross Wightman # See the License for the specific language governing permissions and # limitations under the License. 
import math -import numpy as np +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.models.layers import DropPath, Mlp, to_2tuple, to_ntuple, trunc_normal_ -from timm.models.registry import register_model -from timm.models.helpers import build_model_with_cfg +from timm.layers import DropPath, Mlp, to_2tuple, to_ntuple, trunc_normal_ +from ._builder import build_model_with_cfg +from ._registry import register_model + +__all__ = ['VOLO'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/vovnet.py b/timm/models/vovnet.py index 39d37195..bf0e4f89 100644 --- a/timm/models/vovnet.py +++ b/timm/models/vovnet.py @@ -15,13 +15,15 @@ from typing import List import torch import torch.nn as nn -import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .registry import register_model -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import ConvNormAct, SeparableConvNormAct, BatchNormAct2d, ClassifierHead, DropPath,\ +from timm.layers import ConvNormAct, SeparableConvNormAct, BatchNormAct2d, ClassifierHead, DropPath, \ create_attn, create_norm_act_layer, get_norm_act_layer +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model + +__all__ = ['VovNet'] # model_registry will add each entrypoint fn to this # model cfgs adapted from https://github.com/youngwanLEE/vovnet-detectron2 & diff --git a/timm/models/xception.py b/timm/models/xception.py index 99d02c46..99e74b46 100644 --- a/timm/models/xception.py +++ b/timm/models/xception.py @@ -25,9 +25,9 @@ import torch.jit import torch.nn as nn import torch.nn.functional as F -from .helpers import build_model_with_cfg -from .layers import create_classifier -from .registry import register_model +from timm.layers import create_classifier +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['Xception'] diff --git a/timm/models/xception_aligned.py b/timm/models/xception_aligned.py index 6bbce5e6..e3348e64 100644 --- a/timm/models/xception_aligned.py +++ b/timm/models/xception_aligned.py @@ -11,10 +11,11 @@ import torch import torch.nn as nn from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import ClassifierHead, ConvNormAct, create_conv2d, get_norm_act_layer -from .layers.helpers import to_3tuple -from .registry import register_model +from timm.layers import ClassifierHead, ConvNormAct, create_conv2d, get_norm_act_layer +from timm.layers.helpers import to_3tuple +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model __all__ = ['XceptionAligned'] diff --git a/timm/models/xcit.py b/timm/models/xcit.py index 6802fc84..57c11183 100644 --- a/timm/models/xcit.py +++ b/timm/models/xcit.py @@ -19,12 +19,14 @@ import torch.nn as nn from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .vision_transformer import _cfg, Mlp -from .registry import register_model -from .layers import DropPath, trunc_normal_, to_2tuple +from timm.layers import DropPath, trunc_normal_, to_2tuple +from ._builder import build_model_with_cfg +from 
._features_fx import register_notrace_module +from ._registry import register_model from .cait import ClassAttn -from .fx_features import register_notrace_module +from .vision_transformer import Mlp + +__all__ = ['XCiT'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py index 02f0e250..8613a62c 100644 --- a/timm/optim/optim_factory.py +++ b/timm/optim/optim_factory.py @@ -9,7 +9,7 @@ import torch import torch.nn as nn import torch.optim as optim -from timm.models.helpers import group_parameters +from timm.models import group_parameters from .adabelief import AdaBelief from .adafactor import Adafactor diff --git a/timm/version.py b/timm/version.py index 0f19999f..0716d38a 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.8.0dev0' +__version__ = '0.8.1dev0' diff --git a/train.py b/train.py index d40ff04b..e51d7c90 100755 --- a/train.py +++ b/train.py @@ -31,10 +31,9 @@ from torch.nn.parallel import DistributedDataParallel as NativeDDP from timm import utils from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset -from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, \ - LabelSmoothingCrossEntropy -from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, \ - convert_splitbn_model, convert_sync_batchnorm, model_parameters, set_fast_norm +from timm.layers import convert_splitbn_model, convert_sync_batchnorm, set_fast_norm +from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, LabelSmoothingCrossEntropy +from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, model_parameters from timm.optim import create_optimizer_v2, optimizer_kwargs from timm.scheduler import create_scheduler_v2, scheduler_kwargs from timm.utils import ApexScaler, NativeScaler @@ -82,7 +81,7 @@ parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') # Dataset parameters group = parser.add_argument_group('Dataset parameters') -# Keep this argument outside of the dataset group because it is positional. +# Keep this argument outside the dataset group because it is positional. parser.add_argument('data', nargs='?', metavar='DIR', const=None, help='path to dataset (positional is *deprecated*, use --data-dir)') parser.add_argument('--data-dir', metavar='DIR', @@ -970,16 +969,16 @@ def validate( with amp_autocast(): output = model(input) - if isinstance(output, (tuple, list)): - output = output[0] + if isinstance(output, (tuple, list)): + output = output[0] - # augmentation reduction - reduce_factor = args.tta - if reduce_factor > 1: - output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2) - target = target[0:target.size(0):reduce_factor] + # augmentation reduction + reduce_factor = args.tta + if reduce_factor > 1: + output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2) + target = target[0:target.size(0):reduce_factor] - loss = loss_fn(output, target) + loss = loss_fn(output, target) acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) if args.distributed: diff --git a/validate.py b/validate.py index 6b8222b9..4669fbac 100755 --- a/validate.py +++ b/validate.py @@ -8,22 +8,24 @@ canonical PyTorch, standard Python style, and good performance. 
Repurpose as you Hacked together by Ross Wightman (https://github.com/rwightman) """ import argparse -import os import csv import glob import json -import time import logging -import torch -import torch.nn as nn -import torch.nn.parallel +import os +import time from collections import OrderedDict from contextlib import suppress from functools import partial -from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models, set_fast_norm +import torch +import torch.nn as nn +import torch.nn.parallel + from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet -from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser,\ +from timm.layers import apply_test_time_pool, set_fast_norm +from timm.models import create_model, load_checkpoint, is_model, list_models +from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser, \ decay_batch_step, check_batch_size_retry try: @@ -294,9 +296,9 @@ def validate(args): with amp_autocast(): output = model(input) - if valid_labels is not None: - output = output[:, valid_labels] - loss = criterion(output, target) + if valid_labels is not None: + output = output[:, valid_labels] + loss = criterion(output, target) if real_labels is not None: real_labels.add_result(output)