diff --git a/avg_checkpoints.py b/avg_checkpoints.py index ea8bbe84..83af5bbd 100755 --- a/avg_checkpoints.py +++ b/avg_checkpoints.py @@ -16,7 +16,7 @@ import argparse import os import glob import hashlib -from timm.models.helpers import load_state_dict +from timm.models import load_state_dict parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager') parser.add_argument('--input', default='', type=str, metavar='PATH', diff --git a/benchmark.py b/benchmark.py index 9adeb465..58435ff8 100755 --- a/benchmark.py +++ b/benchmark.py @@ -19,7 +19,8 @@ import torch.nn as nn import torch.nn.parallel from timm.data import resolve_data_config -from timm.models import create_model, is_model, list_models, set_fast_norm +from timm.layers import set_fast_norm +from timm.models import create_model, is_model, list_models from timm.optim import create_optimizer_v2 from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry diff --git a/clean_checkpoint.py b/clean_checkpoint.py index 8ec892b2..17c270db 100755 --- a/clean_checkpoint.py +++ b/clean_checkpoint.py @@ -13,7 +13,7 @@ import os import hashlib import shutil from collections import OrderedDict -from timm.models.helpers import load_state_dict +from timm.models import load_state_dict parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner') parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', diff --git a/hubconf.py b/hubconf.py index 70fed79a..6b2061ea 100644 --- a/hubconf.py +++ b/hubconf.py @@ -1,4 +1,3 @@ dependencies = ['torch'] -from timm.models import registry - -globals().update(registry._model_entrypoints) +import timm +globals().update(timm.models._registry._model_entrypoints) diff --git a/inference.py b/inference.py index bc794840..1509b323 100755 --- a/inference.py +++ b/inference.py @@ -5,11 +5,11 @@ An example inference script that outputs top-k class ids for images in a folder Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman) """ -import os -import time import argparse import json import logging +import os +import time from contextlib import suppress from functools import partial @@ -17,12 +17,11 @@ import numpy as np import pandas as pd import torch -from timm.models import create_model, apply_test_time_pool, load_checkpoint from timm.data import create_dataset, create_loader, resolve_data_config +from timm.layers import apply_test_time_pool +from timm.models import create_model from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser - - try: from apex import amp has_apex = True diff --git a/tests/test_layers.py b/tests/test_layers.py index 508a6aae..da061870 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -1,10 +1,7 @@ -import pytest import torch import torch.nn as nn -import platform -import os -from timm.models.layers import create_act_layer, get_act_layer, set_layer_config +from timm.layers import create_act_layer, set_layer_config class MLP(nn.Module): diff --git a/tests/test_models.py b/tests/test_models.py index 87d75cbd..2392a190 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -14,7 +14,7 @@ except ImportError: import timm from timm import list_models, create_model, set_scriptable, get_pretrained_cfg_value -from timm.models.fx_features import _leaf_modules, _autowrap_functions +from timm.models._features_fx import _leaf_modules, _autowrap_functions if hasattr(torch._C, '_jit_set_profiling_executor'): # legacy executor is too slow to compile large 
models for unit tests diff --git a/timm/__init__.py b/timm/__init__.py index faf34dbc..3d38cdb9 100644 --- a/timm/__init__.py +++ b/timm/__init__.py @@ -1,4 +1,4 @@ from .version import __version__ +from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable from .models import create_model, list_models, list_pretrained, is_model, list_modules, model_entrypoint, \ - is_scriptable, is_exportable, set_scriptable, set_exportable, \ is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value diff --git a/timm/data/readers/class_map.py b/timm/data/readers/class_map.py index 6cf3f57e..885be6e2 100644 --- a/timm/data/readers/class_map.py +++ b/timm/data/readers/class_map.py @@ -1,6 +1,7 @@ import os import pickle + def load_class_map(map_or_filename, root=''): if isinstance(map_or_filename, dict): assert dict, 'class_map dict must be non-empty' @@ -14,7 +15,7 @@ def load_class_map(map_or_filename, root=''): with open(class_map_path) as f: class_to_idx = {v.strip(): k for k, v in enumerate(f)} elif class_map_ext == '.pkl': - with open(class_map_path,'rb') as f: + with open(class_map_path, 'rb') as f: class_to_idx = pickle.load(f) else: assert False, f'Unsupported class map file extension ({class_map_ext}).' diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py new file mode 100644 index 00000000..21c641b6 --- /dev/null +++ b/timm/layers/__init__.py @@ -0,0 +1,44 @@ +from .activations import * +from .adaptive_avgmax_pool import \ + adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d +from .blur_pool import BlurPool2d +from .classifier import ClassifierHead, create_classifier +from .cond_conv2d import CondConv2d, get_condconv_initializer +from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ + set_layer_config +from .conv2d_same import Conv2dSame, conv2d_same +from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct +from .create_act import create_act_layer, get_act_layer, get_act_fn +from .create_attn import get_attn, create_attn +from .create_conv2d import create_conv2d +from .create_norm import get_norm_layer, create_norm_layer +from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer +from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path +from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn +from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\ + EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a +from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm +from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d +from .gather_excite import GatherExcite +from .global_context import GlobalContext +from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple +from .inplace_abn import InplaceAbn +from .linear import Linear +from .mixed_conv2d import MixedConv2d +from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp +from .non_local_attn import NonLocalAttn, BatNonLocalAttn +from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d +from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm +from .padding import get_padding, get_same_padding, pad_same +from .patch_embed import PatchEmbed +from .pool2d_same import AvgPool2dSame, create_pool2d +from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite +from 
.selective_kernel import SelectiveKernel +from .separable_conv import SeparableConv2d, SeparableConvNormAct +from .space_to_depth import SpaceToDepthModule +from .split_attn import SplitAttn +from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model +from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame +from .test_time_pool import TestTimePoolHead, apply_test_time_pool +from .trace_utils import _assert, _float_to_int +from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_ diff --git a/timm/models/layers/activations.py b/timm/layers/activations.py similarity index 100% rename from timm/models/layers/activations.py rename to timm/layers/activations.py diff --git a/timm/models/layers/activations_jit.py b/timm/layers/activations_jit.py similarity index 100% rename from timm/models/layers/activations_jit.py rename to timm/layers/activations_jit.py diff --git a/timm/models/layers/activations_me.py b/timm/layers/activations_me.py similarity index 100% rename from timm/models/layers/activations_me.py rename to timm/layers/activations_me.py diff --git a/timm/models/layers/adaptive_avgmax_pool.py b/timm/layers/adaptive_avgmax_pool.py similarity index 100% rename from timm/models/layers/adaptive_avgmax_pool.py rename to timm/layers/adaptive_avgmax_pool.py diff --git a/timm/models/layers/attention_pool2d.py b/timm/layers/attention_pool2d.py similarity index 100% rename from timm/models/layers/attention_pool2d.py rename to timm/layers/attention_pool2d.py diff --git a/timm/models/layers/blur_pool.py b/timm/layers/blur_pool.py similarity index 100% rename from timm/models/layers/blur_pool.py rename to timm/layers/blur_pool.py diff --git a/timm/models/layers/bottleneck_attn.py b/timm/layers/bottleneck_attn.py similarity index 100% rename from timm/models/layers/bottleneck_attn.py rename to timm/layers/bottleneck_attn.py diff --git a/timm/models/layers/cbam.py b/timm/layers/cbam.py similarity index 100% rename from timm/models/layers/cbam.py rename to timm/layers/cbam.py diff --git a/timm/models/layers/classifier.py b/timm/layers/classifier.py similarity index 100% rename from timm/models/layers/classifier.py rename to timm/layers/classifier.py diff --git a/timm/models/layers/cond_conv2d.py b/timm/layers/cond_conv2d.py similarity index 100% rename from timm/models/layers/cond_conv2d.py rename to timm/layers/cond_conv2d.py diff --git a/timm/models/layers/config.py b/timm/layers/config.py similarity index 100% rename from timm/models/layers/config.py rename to timm/layers/config.py diff --git a/timm/models/layers/conv2d_same.py b/timm/layers/conv2d_same.py similarity index 100% rename from timm/models/layers/conv2d_same.py rename to timm/layers/conv2d_same.py diff --git a/timm/models/layers/conv_bn_act.py b/timm/layers/conv_bn_act.py similarity index 100% rename from timm/models/layers/conv_bn_act.py rename to timm/layers/conv_bn_act.py diff --git a/timm/models/layers/create_act.py b/timm/layers/create_act.py similarity index 100% rename from timm/models/layers/create_act.py rename to timm/layers/create_act.py diff --git a/timm/models/layers/create_attn.py b/timm/layers/create_attn.py similarity index 100% rename from timm/models/layers/create_attn.py rename to timm/layers/create_attn.py diff --git a/timm/models/layers/create_conv2d.py b/timm/layers/create_conv2d.py similarity index 100% rename from timm/models/layers/create_conv2d.py rename to timm/layers/create_conv2d.py diff --git a/timm/models/layers/create_norm.py 
b/timm/layers/create_norm.py similarity index 100% rename from timm/models/layers/create_norm.py rename to timm/layers/create_norm.py diff --git a/timm/models/layers/create_norm_act.py b/timm/layers/create_norm_act.py similarity index 100% rename from timm/models/layers/create_norm_act.py rename to timm/layers/create_norm_act.py diff --git a/timm/models/layers/drop.py b/timm/layers/drop.py similarity index 100% rename from timm/models/layers/drop.py rename to timm/layers/drop.py diff --git a/timm/models/layers/eca.py b/timm/layers/eca.py similarity index 100% rename from timm/models/layers/eca.py rename to timm/layers/eca.py diff --git a/timm/models/layers/evo_norm.py b/timm/layers/evo_norm.py similarity index 100% rename from timm/models/layers/evo_norm.py rename to timm/layers/evo_norm.py diff --git a/timm/models/layers/fast_norm.py b/timm/layers/fast_norm.py similarity index 100% rename from timm/models/layers/fast_norm.py rename to timm/layers/fast_norm.py diff --git a/timm/models/layers/filter_response_norm.py b/timm/layers/filter_response_norm.py similarity index 100% rename from timm/models/layers/filter_response_norm.py rename to timm/layers/filter_response_norm.py diff --git a/timm/models/layers/gather_excite.py b/timm/layers/gather_excite.py similarity index 100% rename from timm/models/layers/gather_excite.py rename to timm/layers/gather_excite.py diff --git a/timm/models/layers/global_context.py b/timm/layers/global_context.py similarity index 100% rename from timm/models/layers/global_context.py rename to timm/layers/global_context.py diff --git a/timm/models/layers/halo_attn.py b/timm/layers/halo_attn.py similarity index 100% rename from timm/models/layers/halo_attn.py rename to timm/layers/halo_attn.py diff --git a/timm/models/layers/helpers.py b/timm/layers/helpers.py similarity index 100% rename from timm/models/layers/helpers.py rename to timm/layers/helpers.py diff --git a/timm/models/layers/inplace_abn.py b/timm/layers/inplace_abn.py similarity index 100% rename from timm/models/layers/inplace_abn.py rename to timm/layers/inplace_abn.py diff --git a/timm/models/layers/lambda_layer.py b/timm/layers/lambda_layer.py similarity index 100% rename from timm/models/layers/lambda_layer.py rename to timm/layers/lambda_layer.py diff --git a/timm/models/layers/linear.py b/timm/layers/linear.py similarity index 100% rename from timm/models/layers/linear.py rename to timm/layers/linear.py diff --git a/timm/models/layers/median_pool.py b/timm/layers/median_pool.py similarity index 100% rename from timm/models/layers/median_pool.py rename to timm/layers/median_pool.py diff --git a/timm/models/layers/mixed_conv2d.py b/timm/layers/mixed_conv2d.py similarity index 100% rename from timm/models/layers/mixed_conv2d.py rename to timm/layers/mixed_conv2d.py diff --git a/timm/models/layers/ml_decoder.py b/timm/layers/ml_decoder.py similarity index 100% rename from timm/models/layers/ml_decoder.py rename to timm/layers/ml_decoder.py diff --git a/timm/models/layers/mlp.py b/timm/layers/mlp.py similarity index 100% rename from timm/models/layers/mlp.py rename to timm/layers/mlp.py diff --git a/timm/models/layers/non_local_attn.py b/timm/layers/non_local_attn.py similarity index 100% rename from timm/models/layers/non_local_attn.py rename to timm/layers/non_local_attn.py diff --git a/timm/models/layers/norm.py b/timm/layers/norm.py similarity index 100% rename from timm/models/layers/norm.py rename to timm/layers/norm.py diff --git a/timm/models/layers/norm_act.py b/timm/layers/norm_act.py 
similarity index 100% rename from timm/models/layers/norm_act.py rename to timm/layers/norm_act.py diff --git a/timm/models/layers/padding.py b/timm/layers/padding.py similarity index 100% rename from timm/models/layers/padding.py rename to timm/layers/padding.py diff --git a/timm/models/layers/patch_embed.py b/timm/layers/patch_embed.py similarity index 100% rename from timm/models/layers/patch_embed.py rename to timm/layers/patch_embed.py diff --git a/timm/models/layers/pool2d_same.py b/timm/layers/pool2d_same.py similarity index 100% rename from timm/models/layers/pool2d_same.py rename to timm/layers/pool2d_same.py diff --git a/timm/models/layers/pos_embed.py b/timm/layers/pos_embed.py similarity index 100% rename from timm/models/layers/pos_embed.py rename to timm/layers/pos_embed.py diff --git a/timm/models/layers/selective_kernel.py b/timm/layers/selective_kernel.py similarity index 100% rename from timm/models/layers/selective_kernel.py rename to timm/layers/selective_kernel.py diff --git a/timm/models/layers/separable_conv.py b/timm/layers/separable_conv.py similarity index 100% rename from timm/models/layers/separable_conv.py rename to timm/layers/separable_conv.py diff --git a/timm/models/layers/space_to_depth.py b/timm/layers/space_to_depth.py similarity index 100% rename from timm/models/layers/space_to_depth.py rename to timm/layers/space_to_depth.py diff --git a/timm/models/layers/split_attn.py b/timm/layers/split_attn.py similarity index 100% rename from timm/models/layers/split_attn.py rename to timm/layers/split_attn.py diff --git a/timm/models/layers/split_batchnorm.py b/timm/layers/split_batchnorm.py similarity index 100% rename from timm/models/layers/split_batchnorm.py rename to timm/layers/split_batchnorm.py diff --git a/timm/models/layers/squeeze_excite.py b/timm/layers/squeeze_excite.py similarity index 100% rename from timm/models/layers/squeeze_excite.py rename to timm/layers/squeeze_excite.py diff --git a/timm/models/layers/std_conv.py b/timm/layers/std_conv.py similarity index 100% rename from timm/models/layers/std_conv.py rename to timm/layers/std_conv.py diff --git a/timm/models/layers/test_time_pool.py b/timm/layers/test_time_pool.py similarity index 100% rename from timm/models/layers/test_time_pool.py rename to timm/layers/test_time_pool.py diff --git a/timm/models/layers/trace_utils.py b/timm/layers/trace_utils.py similarity index 100% rename from timm/models/layers/trace_utils.py rename to timm/layers/trace_utils.py diff --git a/timm/models/layers/weight_init.py b/timm/layers/weight_init.py similarity index 100% rename from timm/models/layers/weight_init.py rename to timm/layers/weight_init.py diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 301186dd..5ecc8915 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -64,12 +64,18 @@ from .xception import * from .xception_aligned import * from .xcit import * -from .factory import create_model, parse_model_name, safe_model_name -from .helpers import load_checkpoint, resume_checkpoint, model_parameters -from .layers import TestTimePoolHead, apply_test_time_pool -from .layers import convert_splitbn_model, convert_sync_batchnorm -from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit -from .layers import set_fast_norm -from .pretrained import PretrainedCfg, filter_pretrained_cfg, generate_default_cfgs, split_model_name_tag -from .registry import register_model, model_entrypoint, list_models, list_pretrained, is_model, 
list_modules,\ +from ._builder import build_model_with_cfg, load_pretrained, load_custom_pretrained, resolve_pretrained_cfg, \ + set_pretrained_download_progress, set_pretrained_check_hash +from ._factory import create_model, parse_model_name, safe_model_name +from ._features import FeatureInfo, FeatureHooks, FeatureHookNet, FeatureListNet, FeatureDictNet +from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, \ + register_notrace_module, register_notrace_function +from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_checkpoint, resume_checkpoint +from ._hub import load_model_config_from_hf, load_state_dict_from_hf, push_to_hf_hub +from ._manipulate import model_parameters, named_apply, named_modules, named_modules_with_params, \ + group_modules, group_parameters, checkpoint_seq, adapt_input_conv +from ._pretrained import PretrainedCfg, DefaultCfg, \ + filter_pretrained_cfg, generate_default_cfgs, split_model_name_tag +from ._prune import adapt_model_from_string +from ._registry import register_model, model_entrypoint, list_models, list_pretrained, is_model, list_modules, \ is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value diff --git a/timm/models/_builder.py b/timm/models/_builder.py new file mode 100644 index 00000000..f634650e --- /dev/null +++ b/timm/models/_builder.py @@ -0,0 +1,399 @@ +import dataclasses +import logging +from copy import deepcopy +from typing import Optional, Dict, Callable, Any, Tuple + +from torch import nn as nn +from torch.hub import load_state_dict_from_url + +from timm.models._features import FeatureListNet, FeatureHookNet +from timm.models._features_fx import FeatureGraphNet +from timm.models._helpers import load_state_dict +from timm.models._hub import has_hf_hub, download_cached_file, load_state_dict_from_hf +from timm.models._manipulate import adapt_input_conv +from timm.models._pretrained import PretrainedCfg +from timm.models._prune import adapt_model_from_file +from timm.models._registry import get_pretrained_cfg + +_logger = logging.getLogger(__name__) + +# Global variables for rarely used pretrained checkpoint download progress and hash check. +# Use set_pretrained_download_progress / set_pretrained_check_hash functions to toggle. 
+_DOWNLOAD_PROGRESS = False +_CHECK_HASH = False + + +__all__ = ['set_pretrained_download_progress', 'set_pretrained_check_hash', 'load_custom_pretrained', 'load_pretrained', + 'pretrained_cfg_for_features', 'resolve_pretrained_cfg', 'build_model_with_cfg'] + + +def _resolve_pretrained_source(pretrained_cfg): + cfg_source = pretrained_cfg.get('source', '') + pretrained_url = pretrained_cfg.get('url', None) + pretrained_file = pretrained_cfg.get('file', None) + hf_hub_id = pretrained_cfg.get('hf_hub_id', None) + # resolve where to load pretrained weights from + load_from = '' + pretrained_loc = '' + if cfg_source == 'hf-hub' and has_hf_hub(necessary=True): + # hf-hub specified as source via model identifier + load_from = 'hf-hub' + assert hf_hub_id + pretrained_loc = hf_hub_id + else: + # default source == timm or unspecified + if pretrained_file: + load_from = 'file' + pretrained_loc = pretrained_file + elif pretrained_url: + load_from = 'url' + pretrained_loc = pretrained_url + elif hf_hub_id and has_hf_hub(necessary=True): + # hf-hub available as alternate weight source in default_cfg + load_from = 'hf-hub' + pretrained_loc = hf_hub_id + if load_from == 'hf-hub' and pretrained_cfg.get('hf_hub_filename', None): + # if a filename override is set, return tuple for location w/ (hub_id, filename) + pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename'] + return load_from, pretrained_loc + + +def set_pretrained_download_progress(enable=True): + """ Set download progress for pretrained weights on/off (globally). """ + global _DOWNLOAD_PROGRESS + _DOWNLOAD_PROGRESS = enable + + +def set_pretrained_check_hash(enable=True): + """ Set hash checking for pretrained weights on/off (globally). """ + global _CHECK_HASH + _CHECK_HASH = enable + + +def load_custom_pretrained( + model: nn.Module, + pretrained_cfg: Optional[Dict] = None, + load_fn: Optional[Callable] = None, +): + r"""Loads a custom (read non .pth) weight file + + Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls + a passed in custom load fn, or the `load_pretrained` model member fn. + + If the object is already present in `model_dir`, it's deserialized and returned. + The default value of `model_dir` is ``<hub_dir>/checkpoints`` where + `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`. + + Args: + model: The instantiated model to load weights into + pretrained_cfg (dict): Default pretrained model cfg + load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named + 'load_pretrained' on the model will be called if it exists + """ + pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None) + if not pretrained_cfg: + _logger.warning("Invalid pretrained config, cannot load weights.") + return + + load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg) + if not load_from: + _logger.warning("No pretrained weights exist for this model. 
Using random initialization.") + return + if load_from == 'hf-hub': # FIXME + _logger.warning("Hugging Face hub not currently supported for custom load pretrained models.") + elif load_from == 'url': + pretrained_loc = download_cached_file( + pretrained_loc, + check_hash=_CHECK_HASH, + progress=_DOWNLOAD_PROGRESS + ) + + if load_fn is not None: + load_fn(model, pretrained_loc) + elif hasattr(model, 'load_pretrained'): + model.load_pretrained(pretrained_loc) + else: + _logger.warning("Valid function to load pretrained weights is not available, using random initialization.") + + +def load_pretrained( + model: nn.Module, + pretrained_cfg: Optional[Dict] = None, + num_classes: int = 1000, + in_chans: int = 3, + filter_fn: Optional[Callable] = None, + strict: bool = True, +): + """ Load pretrained checkpoint + + Args: + model (nn.Module) : PyTorch model module + pretrained_cfg (Optional[Dict]): configuration for pretrained weights / target dataset + num_classes (int): num_classes for target model + in_chans (int): in_chans for target model + filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args) + strict (bool): strict load of checkpoint + + """ + pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None) + if not pretrained_cfg: + _logger.warning("Invalid pretrained config, cannot load weights.") + return + + load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg) + if load_from == 'file': + _logger.info(f'Loading pretrained weights from file ({pretrained_loc})') + state_dict = load_state_dict(pretrained_loc) + elif load_from == 'url': + _logger.info(f'Loading pretrained weights from url ({pretrained_loc})') + state_dict = load_state_dict_from_url( + pretrained_loc, + map_location='cpu', + progress=_DOWNLOAD_PROGRESS, + check_hash=_CHECK_HASH, + ) + elif load_from == 'hf-hub': + _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})') + if isinstance(pretrained_loc, (list, tuple)): + state_dict = load_state_dict_from_hf(*pretrained_loc) + else: + state_dict = load_state_dict_from_hf(pretrained_loc) + else: + _logger.warning("No pretrained weights exist or were found for this model. 
Using random initialization.") + return + + if filter_fn is not None: + # for backwards compat with filter fns that take one arg, try one first, then two + try: + state_dict = filter_fn(state_dict) + except TypeError: + state_dict = filter_fn(state_dict, model) + + input_convs = pretrained_cfg.get('first_conv', None) + if input_convs is not None and in_chans != 3: + if isinstance(input_convs, str): + input_convs = (input_convs,) + for input_conv_name in input_convs: + weight_name = input_conv_name + '.weight' + try: + state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name]) + _logger.info( + f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)') + except NotImplementedError as e: + del state_dict[weight_name] + strict = False + _logger.warning( + f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.') + + classifiers = pretrained_cfg.get('classifier', None) + label_offset = pretrained_cfg.get('label_offset', 0) + if classifiers is not None: + if isinstance(classifiers, str): + classifiers = (classifiers,) + if num_classes != pretrained_cfg['num_classes']: + for classifier_name in classifiers: + # completely discard fully connected if model num_classes doesn't match pretrained weights + state_dict.pop(classifier_name + '.weight', None) + state_dict.pop(classifier_name + '.bias', None) + strict = False + elif label_offset > 0: + for classifier_name in classifiers: + # special case for pretrained weights with an extra background class in pretrained weights + classifier_weight = state_dict[classifier_name + '.weight'] + state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:] + classifier_bias = state_dict[classifier_name + '.bias'] + state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:] + + model.load_state_dict(state_dict, strict=strict) + + +def pretrained_cfg_for_features(pretrained_cfg): + pretrained_cfg = deepcopy(pretrained_cfg) + # remove default pretrained cfg fields that don't have much relevance for feature backbone + to_remove = ('num_classes', 'classifier', 'global_pool') # add default final pool size? 
+ for tr in to_remove: + pretrained_cfg.pop(tr, None) + return pretrained_cfg + + +def _filter_kwargs(kwargs, names): + if not kwargs or not names: + return + for n in names: + kwargs.pop(n, None) + + +def _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter): + """ Update the default_cfg and kwargs before passing to model + + Args: + pretrained_cfg: input pretrained cfg (updated in-place) + kwargs: keyword args passed to model build fn (updated in-place) + kwargs_filter: keyword arg keys that must be removed before model __init__ + """ + # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs) + default_kwarg_names = ('num_classes', 'global_pool', 'in_chans') + if pretrained_cfg.get('fixed_input_size', False): + # if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size + default_kwarg_names += ('img_size',) + + for n in default_kwarg_names: + # for legacy reasons, model __init__args uses img_size + in_chans as separate args while + # pretrained_cfg has one input_size=(C, H ,W) entry + if n == 'img_size': + input_size = pretrained_cfg.get('input_size', None) + if input_size is not None: + assert len(input_size) == 3 + kwargs.setdefault(n, input_size[-2:]) + elif n == 'in_chans': + input_size = pretrained_cfg.get('input_size', None) + if input_size is not None: + assert len(input_size) == 3 + kwargs.setdefault(n, input_size[0]) + else: + default_val = pretrained_cfg.get(n, None) + if default_val is not None: + kwargs.setdefault(n, pretrained_cfg[n]) + + # Filter keyword args for task specific model variants (some 'features only' models, etc.) + _filter_kwargs(kwargs, names=kwargs_filter) + + +def resolve_pretrained_cfg( + variant: str, + pretrained_cfg=None, + pretrained_cfg_overlay=None, +) -> PretrainedCfg: + model_with_tag = variant + pretrained_tag = None + if pretrained_cfg: + if isinstance(pretrained_cfg, dict): + # pretrained_cfg dict passed as arg, validate by converting to PretrainedCfg + pretrained_cfg = PretrainedCfg(**pretrained_cfg) + elif isinstance(pretrained_cfg, str): + pretrained_tag = pretrained_cfg + pretrained_cfg = None + + # fallback to looking up pretrained cfg in model registry by variant identifier + if not pretrained_cfg: + if pretrained_tag: + model_with_tag = '.'.join([variant, pretrained_tag]) + pretrained_cfg = get_pretrained_cfg(model_with_tag) + + if not pretrained_cfg: + _logger.warning( + f"No pretrained configuration specified for {model_with_tag} model. Using a default." 
+ f" Please add a config to the model pretrained_cfg registry or pass explicitly.") + pretrained_cfg = PretrainedCfg() # instance with defaults + + pretrained_cfg_overlay = pretrained_cfg_overlay or {} + if not pretrained_cfg.architecture: + pretrained_cfg_overlay.setdefault('architecture', variant) + pretrained_cfg = dataclasses.replace(pretrained_cfg, **pretrained_cfg_overlay) + + return pretrained_cfg + + +def build_model_with_cfg( + model_cls: Callable, + variant: str, + pretrained: bool, + pretrained_cfg: Optional[Dict] = None, + pretrained_cfg_overlay: Optional[Dict] = None, + model_cfg: Optional[Any] = None, + feature_cfg: Optional[Dict] = None, + pretrained_strict: bool = True, + pretrained_filter_fn: Optional[Callable] = None, + kwargs_filter: Optional[Tuple[str]] = None, + **kwargs, +): + """ Build model with specified default_cfg and optional model_cfg + + This helper fn aids in the construction of a model including: + * handling default_cfg and associated pretrained weight loading + * passing through optional model_cfg for models with config based arch spec + * features_only model adaptation + * pruning config / model adaptation + + Args: + model_cls (nn.Module): model class + variant (str): model variant name + pretrained (bool): load pretrained weights + pretrained_cfg (dict): model's pretrained weight/task config + model_cfg (Optional[Dict]): model's architecture config + feature_cfg (Optional[Dict]: feature extraction adapter config + pretrained_strict (bool): load pretrained weights strictly + pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights + kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model + **kwargs: model args passed through to model __init__ + """ + pruned = kwargs.pop('pruned', False) + features = False + feature_cfg = feature_cfg or {} + + # resolve and update model pretrained config and model kwargs + pretrained_cfg = resolve_pretrained_cfg( + variant, + pretrained_cfg=pretrained_cfg, + pretrained_cfg_overlay=pretrained_cfg_overlay + ) + + # FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model + pretrained_cfg = pretrained_cfg.to_dict() + + _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter) + + # Setup for feature extraction wrapper done at end of this fn + if kwargs.pop('features_only', False): + features = True + feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4)) + if 'out_indices' in kwargs: + feature_cfg['out_indices'] = kwargs.pop('out_indices') + + # Instantiate the model + if model_cfg is None: + model = model_cls(**kwargs) + else: + model = model_cls(cfg=model_cfg, **kwargs) + model.pretrained_cfg = pretrained_cfg + model.default_cfg = model.pretrained_cfg # alias for backwards compat + + if pruned: + model = adapt_model_from_file(model, variant) + + # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats + num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000)) + if pretrained: + if pretrained_cfg.get('custom_load', False): + load_custom_pretrained( + model, + pretrained_cfg=pretrained_cfg, + ) + else: + load_pretrained( + model, + pretrained_cfg=pretrained_cfg, + num_classes=num_classes_pretrained, + in_chans=kwargs.get('in_chans', 3), + filter_fn=pretrained_filter_fn, + strict=pretrained_strict, + ) + + # Wrap the model in a feature extraction module if enabled + if features: + feature_cls = FeatureListNet + if 'feature_cls' in feature_cfg: + 
feature_cls = feature_cfg.pop('feature_cls') + if isinstance(feature_cls, str): + feature_cls = feature_cls.lower() + if 'hook' in feature_cls: + feature_cls = FeatureHookNet + elif feature_cls == 'fx': + feature_cls = FeatureGraphNet + else: + assert False, f'Unknown feature class {feature_cls}' + model = feature_cls(model, **feature_cfg) + model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg) # add back default_cfg + model.default_cfg = model.pretrained_cfg # alias for backwards compat + + return model diff --git a/timm/models/efficientnet_blocks.py b/timm/models/_efficientnet_blocks.py similarity index 99% rename from timm/models/efficientnet_blocks.py rename to timm/models/_efficientnet_blocks.py index 34a31757..92b849e4 100644 --- a/timm/models/efficientnet_blocks.py +++ b/timm/models/_efficientnet_blocks.py @@ -2,13 +2,12 @@ Hacked together by / Copyright 2019, Ross Wightman """ -import math import torch import torch.nn as nn from torch.nn import functional as F -from .layers import create_conv2d, DropPath, make_divisible, create_act_layer, get_norm_act_layer +from timm.layers import create_conv2d, DropPath, make_divisible, create_act_layer, get_norm_act_layer __all__ = [ 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual'] diff --git a/timm/models/efficientnet_builder.py b/timm/models/_efficientnet_builder.py similarity index 99% rename from timm/models/efficientnet_builder.py rename to timm/models/_efficientnet_builder.py index 67d15a86..e6cd05ae 100644 --- a/timm/models/efficientnet_builder.py +++ b/timm/models/_efficientnet_builder.py @@ -14,8 +14,8 @@ from functools import partial import torch.nn as nn -from .efficientnet_blocks import * -from .layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible +from ._efficientnet_blocks import * +from timm.layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible __all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights", 'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT'] diff --git a/timm/models/_factory.py b/timm/models/_factory.py new file mode 100644 index 00000000..a8092419 --- /dev/null +++ b/timm/models/_factory.py @@ -0,0 +1,103 @@ +import os +from typing import Any, Dict, Optional, Union +from urllib.parse import urlsplit + +from timm.layers import set_layer_config +from ._pretrained import PretrainedCfg, split_model_name_tag +from ._helpers import load_checkpoint +from ._hub import load_model_config_from_hf +from ._registry import is_model, model_entrypoint + + +__all__ = ['parse_model_name', 'safe_model_name', 'create_model'] + + +def parse_model_name(model_name): + if model_name.startswith('hf_hub'): + # NOTE for backwards compat, deprecate hf_hub use + model_name = model_name.replace('hf_hub', 'hf-hub') + parsed = urlsplit(model_name) + assert parsed.scheme in ('', 'timm', 'hf-hub') + if parsed.scheme == 'hf-hub': + # FIXME may use fragment as revision, currently `@` in URI path + return parsed.scheme, parsed.path + else: + model_name = os.path.split(parsed.path)[-1] + return 'timm', model_name + + +def safe_model_name(model_name, remove_source=True): + # return a filename / path safe model name + def make_safe(name): + return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_') + if remove_source: + model_name = parse_model_name(model_name)[-1] + return make_safe(model_name) + + +def create_model( + 
model_name: str, + pretrained: bool = False, + pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None, + pretrained_cfg_overlay: Optional[Dict[str, Any]] = None, + checkpoint_path: str = '', + scriptable: Optional[bool] = None, + exportable: Optional[bool] = None, + no_jit: Optional[bool] = None, + **kwargs, +): + """Create a model + + Lookup model's entrypoint function and pass relevant args to create a new model. + + **kwargs will be passed through entrypoint fn to timm.models.build_model_with_cfg() + and then the model class __init__(). kwargs values set to None are pruned before passing. + + Args: + model_name (str): name of model to instantiate + pretrained (bool): load pretrained ImageNet-1k weights if true + pretrained_cfg (Union[str, dict, PretrainedCfg]): pass in external pretrained_cfg for model + pretrained_cfg_overlay (dict): replace key-values in base pretrained_cfg with these + checkpoint_path (str): path of checkpoint to load _after_ the model is initialized + scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet) + exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet) + no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only) + + Keyword Args: + drop_rate (float): dropout rate for training (default: 0.0) + global_pool (str): global pool type (default: 'avg') + **: other kwargs are consumed by builder or model __init__() + """ + # Parameters that aren't supported by all models or are intended to only override model defaults if set + # should default to None in command line args/cfg. Remove them if they are present and not set so that + # non-supporting models don't break and default args remain in effect. + kwargs = {k: v for k, v in kwargs.items() if v is not None} + + model_source, model_name = parse_model_name(model_name) + if model_source == 'hf-hub': + assert not pretrained_cfg, 'pretrained_cfg should not be set when sourcing model from Hugging Face Hub.' + # For model names specified in the form `hf-hub:path/architecture_name@revision`, + # load model weights + pretrained_cfg from Hugging Face hub. + pretrained_cfg, model_name = load_model_config_from_hf(model_name) + else: + model_name, pretrained_tag = split_model_name_tag(model_name) + if not pretrained_cfg: + # a valid pretrained_cfg argument takes priority over tag in model name + pretrained_cfg = pretrained_tag + + if not is_model(model_name): + raise RuntimeError('Unknown model (%s)' % model_name) + + create_fn = model_entrypoint(model_name) + with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit): + model = create_fn( + pretrained=pretrained, + pretrained_cfg=pretrained_cfg, + pretrained_cfg_overlay=pretrained_cfg_overlay, + **kwargs, + ) + + if checkpoint_path: + load_checkpoint(model, checkpoint_path) + + return model diff --git a/timm/models/_features.py b/timm/models/_features.py new file mode 100644 index 00000000..59b080cd --- /dev/null +++ b/timm/models/_features.py @@ -0,0 +1,287 @@ +""" PyTorch Feature Extraction Helpers + +A collection of classes, functions, modules to help extract features from models +and provide a common interface for describing them. 
+ +The return_layers, module re-writing idea inspired by torchvision IntermediateLayerGetter +https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py + +Hacked together by / Copyright 2020 Ross Wightman +""" +from collections import OrderedDict, defaultdict +from copy import deepcopy +from functools import partial +from typing import Dict, List, Tuple + +import torch +import torch.nn as nn + + +__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet'] + + +class FeatureInfo: + + def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]): + prev_reduction = 1 + for fi in feature_info: + # sanity check the mandatory fields, there may be additional fields depending on the model + assert 'num_chs' in fi and fi['num_chs'] > 0 + assert 'reduction' in fi and fi['reduction'] >= prev_reduction + prev_reduction = fi['reduction'] + assert 'module' in fi + self.out_indices = out_indices + self.info = feature_info + + def from_other(self, out_indices: Tuple[int]): + return FeatureInfo(deepcopy(self.info), out_indices) + + def get(self, key, idx=None): + """ Get value by key at specified index (indices) + if idx == None, returns value for key at each output index + if idx is an integer, return value for that feature module index (ignoring output indices) + if idx is a list/tupple, return value for each module index (ignoring output indices) + """ + if idx is None: + return [self.info[i][key] for i in self.out_indices] + if isinstance(idx, (tuple, list)): + return [self.info[i][key] for i in idx] + else: + return self.info[idx][key] + + def get_dicts(self, keys=None, idx=None): + """ return info dicts for specified keys (or all if None) at specified indices (or out_indices if None) + """ + if idx is None: + if keys is None: + return [self.info[i] for i in self.out_indices] + else: + return [{k: self.info[i][k] for k in keys} for i in self.out_indices] + if isinstance(idx, (tuple, list)): + return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx] + else: + return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys} + + def channels(self, idx=None): + """ feature channels accessor + """ + return self.get('num_chs', idx) + + def reduction(self, idx=None): + """ feature reduction (output stride) accessor + """ + return self.get('reduction', idx) + + def module_name(self, idx=None): + """ feature module name accessor + """ + return self.get('module', idx) + + def __getitem__(self, item): + return self.info[item] + + def __len__(self): + return len(self.info) + + +class FeatureHooks: + """ Feature Hook Helper + + This module helps with the setup and extraction of hooks for extracting features from + internal nodes in a model by node name. This works quite well in eager Python but needs + redesign for torchscript. 
+ """ + + def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'): + # setup feature hooks + modules = {k: v for k, v in named_modules} + for i, h in enumerate(hooks): + hook_name = h['module'] + m = modules[hook_name] + hook_id = out_map[i] if out_map else hook_name + hook_fn = partial(self._collect_output_hook, hook_id) + hook_type = h.get('hook_type', default_hook_type) + if hook_type == 'forward_pre': + m.register_forward_pre_hook(hook_fn) + elif hook_type == 'forward': + m.register_forward_hook(hook_fn) + else: + assert False, "Unsupported hook type" + self._feature_outputs = defaultdict(OrderedDict) + + def _collect_output_hook(self, hook_id, *args): + x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre + if isinstance(x, tuple): + x = x[0] # unwrap input tuple + self._feature_outputs[x.device][hook_id] = x + + def get_output(self, device) -> Dict[str, torch.tensor]: + output = self._feature_outputs[device] + self._feature_outputs[device] = OrderedDict() # clear after reading + return output + + +def _module_list(module, flatten_sequential=False): + # a yield/iter would be better for this but wouldn't be compatible with torchscript + ml = [] + for name, module in module.named_children(): + if flatten_sequential and isinstance(module, nn.Sequential): + # first level of Sequential containers is flattened into containing model + for child_name, child_module in module.named_children(): + combined = [name, child_name] + ml.append(('_'.join(combined), '.'.join(combined), child_module)) + else: + ml.append((name, name, module)) + return ml + + +def _get_feature_info(net, out_indices): + feature_info = getattr(net, 'feature_info') + if isinstance(feature_info, FeatureInfo): + return feature_info.from_other(out_indices) + elif isinstance(feature_info, (list, tuple)): + return FeatureInfo(net.feature_info, out_indices) + else: + assert False, "Provided feature_info is not valid" + + +def _get_return_layers(feature_info, out_map): + module_names = feature_info.module_name() + return_layers = {} + for i, name in enumerate(module_names): + return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i] + return return_layers + + +class FeatureDictNet(nn.ModuleDict): + """ Feature extractor with OrderedDict return + + Wrap a model and extract features as specified by the out indices, the network is + partially re-built from contained modules. + + There is a strong assumption that the modules have been registered into the model in the same + order as they are used. There should be no reuse of the same nn.Module more than once, including + trivial modules like `self.relu = nn.ReLU`. + + Only submodules that are directly assigned to the model class (`model.feature1`) or at most + one Sequential container deep (`model.features.1`, with flatten_sequent=True) can be captured. 
+ All Sequential containers that are directly assigned to the original model will have their + modules assigned to this module with the name `model.features.1` being changed to `model.features_1` + + Arguments: + model (nn.Module): model from which we will extract the features + out_indices (tuple[int]): model output indices to extract features for + out_map (sequence): list or tuple specifying desired return id for each out index, + otherwise str(index) is used + feature_concat (bool): whether to concatenate intermediate features that are lists or tuples + vs select element [0] + flatten_sequential (bool): whether to flatten sequential modules assigned to model + """ + def __init__( + self, model, + out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False): + super(FeatureDictNet, self).__init__() + self.feature_info = _get_feature_info(model, out_indices) + self.concat = feature_concat + self.return_layers = {} + return_layers = _get_return_layers(self.feature_info, out_map) + modules = _module_list(model, flatten_sequential=flatten_sequential) + remaining = set(return_layers.keys()) + layers = OrderedDict() + for new_name, old_name, module in modules: + layers[new_name] = module + if old_name in remaining: + # return id has to be consistently str type for torchscript + self.return_layers[new_name] = str(return_layers[old_name]) + remaining.remove(old_name) + if not remaining: + break + assert not remaining and len(self.return_layers) == len(return_layers), \ + f'Return layers ({remaining}) are not present in model' + self.update(layers) + + def _collect(self, x) -> (Dict[str, torch.Tensor]): + out = OrderedDict() + for name, module in self.items(): + x = module(x) + if name in self.return_layers: + out_id = self.return_layers[name] + if isinstance(x, (tuple, list)): + # If model tap is a tuple or list, concat or select first element + # FIXME this may need to be more generic / flexible for some nets + out[out_id] = torch.cat(x, 1) if self.concat else x[0] + else: + out[out_id] = x + return out + + def forward(self, x) -> Dict[str, torch.Tensor]: + return self._collect(x) + + +class FeatureListNet(FeatureDictNet): + """ Feature extractor with list return + + See docstring for FeatureDictNet above, this class exists only to appease Torchscript typing constraints. + In eager Python we could have returned List[Tensor] vs Dict[id, Tensor] based on a member bool. + """ + def __init__( + self, model, + out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False): + super(FeatureListNet, self).__init__( + model, out_indices=out_indices, out_map=out_map, feature_concat=feature_concat, + flatten_sequential=flatten_sequential) + + def forward(self, x) -> (List[torch.Tensor]): + return list(self._collect(x).values()) + + +class FeatureHookNet(nn.ModuleDict): + """ FeatureHookNet + + Wrap a model and extract features specified by the out indices using forward/forward-pre hooks. + + If `no_rewrite` is True, features are extracted via hooks without modifying the underlying + network in any way. + + If `no_rewrite` is False, the model will be re-written as in the + FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one. 
+ + FIXME this does not currently work with Torchscript, see FeatureHooks class + """ + def __init__( + self, model, + out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, no_rewrite=False, + feature_concat=False, flatten_sequential=False, default_hook_type='forward'): + super(FeatureHookNet, self).__init__() + assert not torch.jit.is_scripting() + self.feature_info = _get_feature_info(model, out_indices) + self.out_as_dict = out_as_dict + layers = OrderedDict() + hooks = [] + if no_rewrite: + assert not flatten_sequential + if hasattr(model, 'reset_classifier'): # make sure classifier is removed? + model.reset_classifier(0) + layers['body'] = model + hooks.extend(self.feature_info.get_dicts()) + else: + modules = _module_list(model, flatten_sequential=flatten_sequential) + remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type + for f in self.feature_info.get_dicts()} + for new_name, old_name, module in modules: + layers[new_name] = module + for fn, fm in module.named_modules(prefix=old_name): + if fn in remaining: + hooks.append(dict(module=fn, hook_type=remaining[fn])) + del remaining[fn] + if not remaining: + break + assert not remaining, f'Return layers ({remaining}) are not present in model' + self.update(layers) + self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map) + + def forward(self, x): + for name, module in self.items(): + x = module(x) + out = self.hooks.get_output(x.device) + return out if self.out_as_dict else list(out.values()) diff --git a/timm/models/_features_fx.py b/timm/models/_features_fx.py new file mode 100644 index 00000000..10670a1d --- /dev/null +++ b/timm/models/_features_fx.py @@ -0,0 +1,110 @@ +""" PyTorch FX Based Feature Extraction Helpers +Using https://pytorch.org/vision/stable/feature_extraction.html +""" +from typing import Callable, List, Dict, Union, Type + +import torch +from torch import nn + +from ._features import _get_feature_info + +try: + from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor + has_fx_feature_extraction = True +except ImportError: + has_fx_feature_extraction = False + +# Layers we went to treat as leaf modules +from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame +from timm.layers.non_local_attn import BilinearAttnTransform +from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame + +# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here +# BUT modules from timm.models should use the registration mechanism below +_leaf_modules = { + BilinearAttnTransform, # reason: flow control t <= 1 + # Reason: get_same_padding has a max which raises a control flow error + Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame, + CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0]) +} + +try: + from timm.layers import InplaceAbn + _leaf_modules.add(InplaceAbn) +except ImportError: + pass + + +__all__ = ['register_notrace_module', 'register_notrace_function', 'create_feature_extractor', + 'FeatureGraphNet', 'GraphExtractNet'] + + +def register_notrace_module(module: Type[nn.Module]): + """ + Any module not under timm.models.layers should get this decorator if we don't want to trace through it. 
+ """ + _leaf_modules.add(module) + return module + + +# Functions we want to autowrap (treat them as leaves) +_autowrap_functions = set() + + +def register_notrace_function(func: Callable): + """ + Decorator for functions which ought not to be traced through + """ + _autowrap_functions.add(func) + return func + + +def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]): + assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction' + return _create_feature_extractor( + model, return_nodes, + tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)} + ) + + +class FeatureGraphNet(nn.Module): + """ A FX Graph based feature extractor that works with the model feature_info metadata + """ + def __init__(self, model, out_indices, out_map=None): + super().__init__() + assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction' + self.feature_info = _get_feature_info(model, out_indices) + if out_map is not None: + assert len(out_map) == len(out_indices) + return_nodes = { + info['module']: out_map[i] if out_map is not None else info['module'] + for i, info in enumerate(self.feature_info) if i in out_indices} + self.graph_module = create_feature_extractor(model, return_nodes) + + def forward(self, x): + return list(self.graph_module(x).values()) + + +class GraphExtractNet(nn.Module): + """ A standalone feature extraction wrapper that maps dict -> list or single tensor + NOTE: + * one can use feature_extractor directly if dictionary output is desired + * unlike FeatureGraphNet, this is intended to be used standalone and not with model feature_info + metadata for builtin feature extraction mode + * create_feature_extractor can be used directly if dictionary output is acceptable + + Args: + model: model to extract features from + return_nodes: node names to return features from (dict or list) + squeeze_out: if only one output, and output in list format, flatten to single tensor + """ + def __init__(self, model, return_nodes: Union[Dict[str, str], List[str]], squeeze_out: bool = True): + super().__init__() + self.squeeze_out = squeeze_out + self.graph_module = create_feature_extractor(model, return_nodes) + + def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]: + out = list(self.graph_module(x).values()) + if self.squeeze_out and len(out) == 1: + return out[0] + return out diff --git a/timm/models/_helpers.py b/timm/models/_helpers.py new file mode 100644 index 00000000..995292aa --- /dev/null +++ b/timm/models/_helpers.py @@ -0,0 +1,115 @@ +""" Model creation / weight loading / state_dict helpers + +Hacked together by / Copyright 2020 Ross Wightman +""" +import logging +import os +from collections import OrderedDict + +import torch + +import timm.models._builder + +_logger = logging.getLogger(__name__) + +__all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_checkpoint', 'resume_checkpoint'] + + +def clean_state_dict(state_dict): + # 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training + cleaned_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k[7:] if k.startswith('module.') else k + cleaned_state_dict[name] = v + return cleaned_state_dict + + +def load_state_dict(checkpoint_path, use_ema=True): + if checkpoint_path and os.path.isfile(checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location='cpu') + 
state_dict_key = '' + if isinstance(checkpoint, dict): + if use_ema and checkpoint.get('state_dict_ema', None) is not None: + state_dict_key = 'state_dict_ema' + elif use_ema and checkpoint.get('model_ema', None) is not None: + state_dict_key = 'model_ema' + elif 'state_dict' in checkpoint: + state_dict_key = 'state_dict' + elif 'model' in checkpoint: + state_dict_key = 'model' + state_dict = clean_state_dict(checkpoint[state_dict_key] if state_dict_key else checkpoint) + _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path)) + return state_dict + else: + _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) + raise FileNotFoundError() + + +def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True, remap=False): + if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'): + # numpy checkpoint, try to load via model specific load_pretrained fn + if hasattr(model, 'load_pretrained'): + model.load_pretrained(checkpoint_path) + else: + raise NotImplementedError('Model cannot load numpy checkpoint') + return + state_dict = load_state_dict(checkpoint_path, use_ema) + if remap: + state_dict = remap_checkpoint(model, state_dict) + incompatible_keys = model.load_state_dict(state_dict, strict=strict) + return incompatible_keys + + +def remap_checkpoint(model, state_dict, allow_reshape=True): + """ remap checkpoint by iterating over state dicts in order (ignoring original keys). + This assumes models (and originating state dict) were created with params registered in same order. + """ + out_dict = {} + for (ka, va), (kb, vb) in zip(model.state_dict().items(), state_dict.items()): + assert va.numel() == vb.numel(), f'Tensor size mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.' + if va.shape != vb.shape: + if allow_reshape: + vb = vb.reshape(va.shape) + else: + assert False, f'Tensor shape mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.' 
+ out_dict[ka] = vb + return out_dict + + +def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True): + resume_epoch = None + if os.path.isfile(checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location='cpu') + if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + if log_info: + _logger.info('Restoring model state from checkpoint...') + state_dict = clean_state_dict(checkpoint['state_dict']) + model.load_state_dict(state_dict) + + if optimizer is not None and 'optimizer' in checkpoint: + if log_info: + _logger.info('Restoring optimizer state from checkpoint...') + optimizer.load_state_dict(checkpoint['optimizer']) + + if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint: + if log_info: + _logger.info('Restoring AMP loss scaler state from checkpoint...') + loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key]) + + if 'epoch' in checkpoint: + resume_epoch = checkpoint['epoch'] + if 'version' in checkpoint and checkpoint['version'] > 1: + resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save + + if log_info: + _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) + else: + model.load_state_dict(checkpoint) + if log_info: + _logger.info("Loaded checkpoint '{}'".format(checkpoint_path)) + return resume_epoch + else: + _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) + raise FileNotFoundError() + + diff --git a/timm/models/_hub.py b/timm/models/_hub.py new file mode 100644 index 00000000..e6b7d558 --- /dev/null +++ b/timm/models/_hub.py @@ -0,0 +1,220 @@ +import json +import logging +import os +from functools import partial +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Optional, Union + +import torch +from torch.hub import HASH_REGEX, download_url_to_file, urlparse + +try: + from torch.hub import get_dir +except ImportError: + from torch.hub import _get_torch_home as get_dir + +from timm import __version__ +from timm.models._pretrained import filter_pretrained_cfg + +try: + from huggingface_hub import ( + create_repo, get_hf_file_metadata, + hf_hub_download, hf_hub_url, + repo_type_and_id_from_hf_id, upload_folder) + from huggingface_hub.utils import EntryNotFoundError + hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__) + _has_hf_hub = True +except ImportError: + hf_hub_download = None + _has_hf_hub = False + +_logger = logging.getLogger(__name__) + +__all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf', + 'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub'] + + +def get_cache_dir(child_dir=''): + """ + Returns the location of the directory where models are cached (and creates it if necessary). 
+ """ + # Issue warning to move data if old env is set + if os.getenv('TORCH_MODEL_ZOO'): + _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead') + + hub_dir = get_dir() + child_dir = () if not child_dir else (child_dir,) + model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir) + os.makedirs(model_dir, exist_ok=True) + return model_dir + + +def download_cached_file(url, check_hash=True, progress=False): + if isinstance(url, (list, tuple)): + url, filename = url + else: + parts = urlparse(url) + filename = os.path.basename(parts.path) + cached_file = os.path.join(get_cache_dir(), filename) + if not os.path.exists(cached_file): + _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file)) + hash_prefix = None + if check_hash: + r = HASH_REGEX.search(filename) # r is Optional[Match[str]] + hash_prefix = r.group(1) if r else None + download_url_to_file(url, cached_file, hash_prefix, progress=progress) + return cached_file + + +def has_hf_hub(necessary=False): + if not _has_hf_hub and necessary: + # if no HF Hub module installed, and it is necessary to continue, raise error + raise RuntimeError( + 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') + return _has_hf_hub + + +def hf_split(hf_id): + # FIXME I may change @ -> # and be parsed as fragment in a URI model name scheme + rev_split = hf_id.split('@') + assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.' + hf_model_id = rev_split[0] + hf_revision = rev_split[-1] if len(rev_split) > 1 else None + return hf_model_id, hf_revision + + +def load_cfg_from_json(json_file: Union[str, os.PathLike]): + with open(json_file, "r", encoding="utf-8") as reader: + text = reader.read() + return json.loads(text) + + +def _download_from_hf(model_id: str, filename: str): + hf_model_id, hf_revision = hf_split(model_id) + return hf_hub_download(hf_model_id, filename, revision=hf_revision) + + +def load_model_config_from_hf(model_id: str): + assert has_hf_hub(True) + cached_file = _download_from_hf(model_id, 'config.json') + + hf_config = load_cfg_from_json(cached_file) + if 'pretrained_cfg' not in hf_config: + # old form, pull pretrain_cfg out of the base dict + pretrained_cfg = hf_config + hf_config = {} + hf_config['architecture'] = pretrained_cfg.pop('architecture') + hf_config['num_features'] = pretrained_cfg.pop('num_features', None) + if 'labels' in pretrained_cfg: + hf_config['label_name'] = pretrained_cfg.pop('labels') + hf_config['pretrained_cfg'] = pretrained_cfg + + # NOTE currently discarding parent config as only arch name and pretrained_cfg used in timm right now + pretrained_cfg = hf_config['pretrained_cfg'] + pretrained_cfg['hf_hub_id'] = model_id # insert hf_hub id for pretrained weight load during model creation + pretrained_cfg['source'] = 'hf-hub' + if 'num_classes' in hf_config: + # model should be created with parent num_classes if they exist + pretrained_cfg['num_classes'] = hf_config['num_classes'] + model_name = hf_config['architecture'] + + return pretrained_cfg, model_name + + +def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'): + assert has_hf_hub(True) + cached_file = _download_from_hf(model_id, filename) + state_dict = torch.load(cached_file, map_location='cpu') + return state_dict + + +def save_for_hf(model, save_directory, model_config=None): + assert has_hf_hub(True) + model_config = model_config or {} + save_directory = Path(save_directory) + 
save_directory.mkdir(exist_ok=True, parents=True)
+
+    weights_path = save_directory / 'pytorch_model.bin'
+    torch.save(model.state_dict(), weights_path)
+
+    config_path = save_directory / 'config.json'
+    hf_config = {}
+    pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
+    # set some values at root config level
+    hf_config['architecture'] = pretrained_cfg.pop('architecture')
+    hf_config['num_classes'] = model_config.get('num_classes', model.num_classes)
+    hf_config['num_features'] = model_config.get('num_features', model.num_features)
+    hf_config['global_pool'] = model_config.get('global_pool', getattr(model, 'global_pool', None))
+
+    if 'label' in model_config:
+        _logger.warning(
+            "'label' as a config field for timm models is deprecated. Please use 'label_name' and 'display_name'. "
+            "Using provided 'label' field as 'label_name'.")
+        model_config['label_name'] = model_config.pop('label')
+
+    label_name = model_config.pop('label_name', None)
+    if label_name:
+        assert isinstance(label_name, (dict, list, tuple))
+        # map label id (classifier index) -> unique label name (ie synset for ImageNet, MID for OpenImages)
+        # can be a dict id: name if there are id gaps, or tuple/list if no gaps.
+        hf_config['label_name'] = label_name
+
+    display_name = model_config.pop('display_name', None)
+    if display_name:
+        assert isinstance(display_name, dict)
+        # map label_name -> user interface display name
+        hf_config['display_name'] = display_name
+
+    hf_config['pretrained_cfg'] = pretrained_cfg
+    hf_config.update(model_config)
+
+    with config_path.open('w') as f:
+        json.dump(hf_config, f, indent=2)
+
+
+def push_to_hf_hub(
+    model,
+    repo_id: str,
+    commit_message: str = 'Add model',
+    token: Optional[str] = None,
+    revision: Optional[str] = None,
+    private: bool = False,
+    create_pr: bool = False,
+    model_config: Optional[dict] = None,
+):
+    # Create repo if it doesn't exist yet
+    repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
+
+    # Infer complete repo_id from repo_url
+    # Can be different from the input `repo_id` if repo_owner was implicit
+    _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
+    repo_id = f"{repo_owner}/{repo_name}"
+
+    # Check if README file already exists in repo
+    try:
+        get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
+        has_readme = True
+    except EntryNotFoundError:
+        has_readme = False
+
+    # Dump model and push to Hub
+    with TemporaryDirectory() as tmpdir:
+        # Save model weights and config.
+ save_for_hf(model, tmpdir, model_config=model_config) + + # Add readme if it does not exist + if not has_readme: + model_name = repo_id.split('/')[-1] + readme_path = Path(tmpdir) / "README.md" + readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {model_name}' + readme_path.write_text(readme_text) + + # Upload model and return + return upload_folder( + repo_id=repo_id, + folder_path=tmpdir, + revision=revision, + create_pr=create_pr, + commit_message=commit_message, + ) diff --git a/timm/models/_manipulate.py b/timm/models/_manipulate.py new file mode 100644 index 00000000..192979fc --- /dev/null +++ b/timm/models/_manipulate.py @@ -0,0 +1,258 @@ +import collections.abc +import math +import re +from collections import defaultdict +from itertools import chain +from typing import Callable, Union, Dict + +import torch +from torch import nn as nn +from torch.utils.checkpoint import checkpoint + +__all__ = ['model_parameters', 'named_apply', 'named_modules', 'named_modules_with_params', 'adapt_input_conv', + 'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq'] + + +def model_parameters(model, exclude_head=False): + if exclude_head: + # FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering + return [p for p in model.parameters()][:-2] + else: + return model.parameters() + + +def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = '.'.join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +def named_modules(module: nn.Module, name='', depth_first=True, include_root=False): + if not depth_first and include_root: + yield name, module + for child_name, child_module in module.named_children(): + child_name = '.'.join((name, child_name)) if name else child_name + yield from named_modules( + module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + yield name, module + + +def named_modules_with_params(module: nn.Module, name='', depth_first=True, include_root=False): + if module._parameters and not depth_first and include_root: + yield name, module + for child_name, child_module in module.named_children(): + child_name = '.'.join((name, child_name)) if name else child_name + yield from named_modules_with_params( + module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if module._parameters and depth_first and include_root: + yield name, module + + +MATCH_PREV_GROUP = (99999,) + + +def group_with_matcher( + named_objects, + group_matcher: Union[Dict, Callable], + output_values: bool = False, + reverse: bool = False +): + if isinstance(group_matcher, dict): + # dictionary matcher contains a dict of raw-string regex expr that must be compiled + compiled = [] + for group_ordinal, (group_name, mspec) in enumerate(group_matcher.items()): + if mspec is None: + continue + # map all matching specifications into 3-tuple (compiled re, prefix, suffix) + if isinstance(mspec, (tuple, list)): + # multi-entry match specifications require each sub-spec to be a 2-tuple (re, suffix) + for sspec in mspec: + compiled += 
[(re.compile(sspec[0]), (group_ordinal,), sspec[1])] + else: + compiled += [(re.compile(mspec), (group_ordinal,), None)] + group_matcher = compiled + + def _get_grouping(name): + if isinstance(group_matcher, (list, tuple)): + for match_fn, prefix, suffix in group_matcher: + r = match_fn.match(name) + if r: + parts = (prefix, r.groups(), suffix) + # map all tuple elem to int for numeric sort, filter out None entries + return tuple(map(float, chain.from_iterable(filter(None, parts)))) + return float('inf'), # un-matched layers (neck, head) mapped to largest ordinal + else: + ord = group_matcher(name) + if not isinstance(ord, collections.abc.Iterable): + return ord, + return tuple(ord) + + # map layers into groups via ordinals (ints or tuples of ints) from matcher + grouping = defaultdict(list) + for k, v in named_objects: + grouping[_get_grouping(k)].append(v if output_values else k) + + # remap to integers + layer_id_to_param = defaultdict(list) + lid = -1 + for k in sorted(filter(lambda x: x is not None, grouping.keys())): + if lid < 0 or k[-1] != MATCH_PREV_GROUP[0]: + lid += 1 + layer_id_to_param[lid].extend(grouping[k]) + + if reverse: + assert not output_values, "reverse mapping only sensible for name output" + # output reverse mapping + param_to_layer_id = {} + for lid, lm in layer_id_to_param.items(): + for n in lm: + param_to_layer_id[n] = lid + return param_to_layer_id + + return layer_id_to_param + + +def group_parameters( + module: nn.Module, + group_matcher, + output_values=False, + reverse=False, +): + return group_with_matcher( + module.named_parameters(), group_matcher, output_values=output_values, reverse=reverse) + + +def group_modules( + module: nn.Module, + group_matcher, + output_values=False, + reverse=False, +): + return group_with_matcher( + named_modules_with_params(module), group_matcher, output_values=output_values, reverse=reverse) + + +def flatten_modules(named_modules, depth=1, prefix='', module_types='sequential'): + prefix_is_tuple = isinstance(prefix, tuple) + if isinstance(module_types, str): + if module_types == 'container': + module_types = (nn.Sequential, nn.ModuleList, nn.ModuleDict) + else: + module_types = (nn.Sequential,) + for name, module in named_modules: + if depth and isinstance(module, module_types): + yield from flatten_modules( + module.named_children(), + depth - 1, + prefix=(name,) if prefix_is_tuple else name, + module_types=module_types, + ) + else: + if prefix_is_tuple: + name = prefix + (name,) + yield name, module + else: + if prefix: + name = '.'.join([prefix, name]) + yield name, module + + +def checkpoint_seq( + functions, + x, + every=1, + flatten=False, + skip_last=False, + preserve_rng_state=True +): + r"""A helper function for checkpointing sequential models. + + Sequential models execute a list of modules/functions in order + (sequentially). Therefore, we can divide such a sequence into segments + and checkpoint each segment. All segments except run in :func:`torch.no_grad` + manner, i.e., not storing the intermediate activations. The inputs of each + checkpointed segment will be saved for re-running the segment in the backward pass. + + See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works. + + .. warning:: + Checkpointing currently only supports :func:`torch.autograd.backward` + and only if its `inputs` argument is not passed. :func:`torch.autograd.grad` + is not supported. + + .. 
warning: + At least one of the inputs needs to have :code:`requires_grad=True` if + grads are needed for model inputs, otherwise the checkpointed part of the + model won't have gradients. + + Args: + functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially. + x: A Tensor that is input to :attr:`functions` + every: checkpoint every-n functions (default: 1) + flatten (bool): flatten nn.Sequential of nn.Sequentials + skip_last (bool): skip checkpointing the last function in the sequence if True + preserve_rng_state (bool, optional, default=True): Omit stashing and restoring + the RNG state during each checkpoint. + + Returns: + Output of running :attr:`functions` sequentially on :attr:`*inputs` + + Example: + >>> model = nn.Sequential(...) + >>> input_var = checkpoint_seq(model, input_var, every=2) + """ + def run_function(start, end, functions): + def forward(_x): + for j in range(start, end + 1): + _x = functions[j](_x) + return _x + return forward + + if isinstance(functions, torch.nn.Sequential): + functions = functions.children() + if flatten: + functions = chain.from_iterable(functions) + if not isinstance(functions, (tuple, list)): + functions = tuple(functions) + + num_checkpointed = len(functions) + if skip_last: + num_checkpointed -= 1 + end = -1 + for start in range(0, num_checkpointed, every): + end = min(start + every - 1, num_checkpointed - 1) + x = checkpoint(run_function(start, end, functions), x, preserve_rng_state=preserve_rng_state) + if skip_last: + return run_function(end + 1, len(functions) - 1, functions)(x) + return x + + +def adapt_input_conv(in_chans, conv_weight): + conv_type = conv_weight.dtype + conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU + O, I, J, K = conv_weight.shape + if in_chans == 1: + if I > 3: + assert conv_weight.shape[1] % 3 == 0 + # For models with space2depth stems + conv_weight = conv_weight.reshape(O, I // 3, 3, J, K) + conv_weight = conv_weight.sum(dim=2, keepdim=False) + else: + conv_weight = conv_weight.sum(dim=1, keepdim=True) + elif in_chans != 3: + if I != 3: + raise NotImplementedError('Weight format not supported by conversion.') + else: + # NOTE this strategy should be better than random init, but there could be other combinations of + # the original RGB input layer weights that'd work better for specific cases. 
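A sketch of how checkpoint_seq is typically wired into a forward pass; the blocks/grad_checkpointing attribute names mirror common timm models but are assumptions here, not part of this diff:

import torch
import torch.nn as nn
from timm.models._manipulate import checkpoint_seq

class ToyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.Sequential(*[nn.Linear(16, 16) for _ in range(8)])
        self.grad_checkpointing = True

    def forward(self, x):
        if self.grad_checkpointing and not torch.jit.is_scripting():
            # checkpoint every 2 blocks; activations are recomputed in backward
            return checkpoint_seq(self.blocks, x, every=2)
        return self.blocks(x)

y = ToyNet()(torch.randn(4, 16, requires_grad=True))
y.sum().backward()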
+ repeat = int(math.ceil(in_chans / 3)) + conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] + conv_weight *= (3 / float(in_chans)) + conv_weight = conv_weight.to(conv_type) + return conv_weight diff --git a/timm/models/pretrained.py b/timm/models/_pretrained.py similarity index 97% rename from timm/models/pretrained.py rename to timm/models/_pretrained.py index 2ca7ac5a..b5ecbc50 100644 --- a/timm/models/pretrained.py +++ b/timm/models/_pretrained.py @@ -4,6 +4,9 @@ from dataclasses import dataclass, field, replace, asdict from typing import Any, Deque, Dict, Tuple, Optional, Union +__all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg', 'split_model_name_tag', 'generate_default_cfgs'] + + @dataclass class PretrainedCfg: """ diff --git a/timm/models/_prune.py b/timm/models/_prune.py new file mode 100644 index 00000000..4e744dec --- /dev/null +++ b/timm/models/_prune.py @@ -0,0 +1,113 @@ +import os +from copy import deepcopy + +from torch import nn as nn + +from timm.layers import Conv2dSame, BatchNormAct2d, Linear + +__all__ = ['extract_layer', 'set_layer', 'adapt_model_from_string', 'adapt_model_from_file'] + + +def extract_layer(model, layer): + layer = layer.split('.') + module = model + if hasattr(model, 'module') and layer[0] != 'module': + module = model.module + if not hasattr(model, 'module') and layer[0] == 'module': + layer = layer[1:] + for l in layer: + if hasattr(module, l): + if not l.isdigit(): + module = getattr(module, l) + else: + module = module[int(l)] + else: + return module + return module + + +def set_layer(model, layer, val): + layer = layer.split('.') + module = model + if hasattr(model, 'module') and layer[0] != 'module': + module = model.module + lst_index = 0 + module2 = module + for l in layer: + if hasattr(module2, l): + if not l.isdigit(): + module2 = getattr(module2, l) + else: + module2 = module2[int(l)] + lst_index += 1 + lst_index -= 1 + for l in layer[:lst_index]: + if not l.isdigit(): + module = getattr(module, l) + else: + module = module[int(l)] + l = layer[lst_index] + setattr(module, l, val) + + +def adapt_model_from_string(parent_module, model_string): + separator = '***' + state_dict = {} + lst_shape = model_string.split(separator) + for k in lst_shape: + k = k.split(':') + key = k[0] + shape = k[1][1:-1].split(',') + if shape[0] != '': + state_dict[key] = [int(i) for i in shape] + + new_module = deepcopy(parent_module) + for n, m in parent_module.named_modules(): + old_module = extract_layer(parent_module, n) + if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame): + if isinstance(old_module, Conv2dSame): + conv = Conv2dSame + else: + conv = nn.Conv2d + s = state_dict[n + '.weight'] + in_channels = s[1] + out_channels = s[0] + g = 1 + if old_module.groups > 1: + in_channels = out_channels + g = in_channels + new_conv = conv( + in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size, + bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation, + groups=g, stride=old_module.stride) + set_layer(new_module, n, new_conv) + elif isinstance(old_module, BatchNormAct2d): + new_bn = BatchNormAct2d( + state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum, + affine=old_module.affine, track_running_stats=True) + new_bn.drop = old_module.drop + new_bn.act = old_module.act + set_layer(new_module, n, new_bn) + elif isinstance(old_module, nn.BatchNorm2d): + new_bn = nn.BatchNorm2d( + num_features=state_dict[n + 
'.weight'][0], eps=old_module.eps, momentum=old_module.momentum, + affine=old_module.affine, track_running_stats=True) + set_layer(new_module, n, new_bn) + elif isinstance(old_module, nn.Linear): + # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer? + num_features = state_dict[n + '.weight'][1] + new_fc = Linear( + in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None) + set_layer(new_module, n, new_fc) + if hasattr(new_module, 'num_features'): + new_module.num_features = num_features + new_module.eval() + parent_module.eval() + + return new_module + + +def adapt_model_from_file(parent_module, model_variant): + adapt_file = os.path.join(os.path.dirname(__file__), '_pruned', model_variant + '.txt') + with open(adapt_file, 'r') as f: + return adapt_model_from_string(parent_module, f.read().strip()) diff --git a/timm/models/pruned/ecaresnet101d_pruned.txt b/timm/models/_pruned/ecaresnet101d_pruned.txt similarity index 100% rename from timm/models/pruned/ecaresnet101d_pruned.txt rename to timm/models/_pruned/ecaresnet101d_pruned.txt diff --git a/timm/models/pruned/ecaresnet50d_pruned.txt b/timm/models/_pruned/ecaresnet50d_pruned.txt similarity index 100% rename from timm/models/pruned/ecaresnet50d_pruned.txt rename to timm/models/_pruned/ecaresnet50d_pruned.txt diff --git a/timm/models/pruned/efficientnet_b1_pruned.txt b/timm/models/_pruned/efficientnet_b1_pruned.txt similarity index 100% rename from timm/models/pruned/efficientnet_b1_pruned.txt rename to timm/models/_pruned/efficientnet_b1_pruned.txt diff --git a/timm/models/pruned/efficientnet_b2_pruned.txt b/timm/models/_pruned/efficientnet_b2_pruned.txt similarity index 100% rename from timm/models/pruned/efficientnet_b2_pruned.txt rename to timm/models/_pruned/efficientnet_b2_pruned.txt diff --git a/timm/models/pruned/efficientnet_b3_pruned.txt b/timm/models/_pruned/efficientnet_b3_pruned.txt similarity index 100% rename from timm/models/pruned/efficientnet_b3_pruned.txt rename to timm/models/_pruned/efficientnet_b3_pruned.txt diff --git a/timm/models/_registry.py b/timm/models/_registry.py new file mode 100644 index 00000000..fc7b3437 --- /dev/null +++ b/timm/models/_registry.py @@ -0,0 +1,212 @@ +""" Model Registry +Hacked together by / Copyright 2020 Ross Wightman +""" + +import fnmatch +import re +import sys +from collections import defaultdict, deque +from copy import deepcopy +from typing import List, Optional, Union, Tuple + +from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag + +__all__ = [ + 'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules', + 'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name'] + +_module_to_models = defaultdict(set) # dict of sets to check membership of model in module +_model_to_module = {} # mapping of model names to module names +_model_entrypoints = {} # mapping of model names to architecture entrypoint fns +_model_has_pretrained = set() # set of model names that have pretrained weight url present +_model_default_cfgs = dict() # central repo for model arch -> default cfg objects +_model_pretrained_cfgs = dict() # central repo for model arch + tag -> pretrained cfgs +_model_with_tags = defaultdict(list) # shortcut to map each model arch to all model + tag names + + +def get_arch_name(model_name: str) -> Tuple[str, Optional[str]]: + return split_model_name_tag(model_name)[0] + + +def register_model(fn): + # lookup containing 
module + mod = sys.modules[fn.__module__] + module_name_split = fn.__module__.split('.') + module_name = module_name_split[-1] if len(module_name_split) else '' + + # add model to __all__ in module + model_name = fn.__name__ + if hasattr(mod, '__all__'): + mod.__all__.append(model_name) + else: + mod.__all__ = [model_name] + + # add entries to registry dict/sets + _model_entrypoints[model_name] = fn + _model_to_module[model_name] = module_name + _module_to_models[module_name].add(model_name) + if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs: + # this will catch all models that have entrypoint matching cfg key, but miss any aliasing + # entrypoints or non-matching combos + cfg = mod.default_cfgs[model_name] + if not isinstance(cfg, DefaultCfg): + # new style default cfg dataclass w/ multiple entries per model-arch + assert isinstance(cfg, dict) + # old style cfg dict per model-arch + cfg = PretrainedCfg(**cfg) + cfg = DefaultCfg(tags=deque(['']), cfgs={'': cfg}) + + for tag_idx, tag in enumerate(cfg.tags): + is_default = tag_idx == 0 + pretrained_cfg = cfg.cfgs[tag] + if is_default: + _model_pretrained_cfgs[model_name] = pretrained_cfg + if pretrained_cfg.has_weights: + # add tagless entry if it's default and has weights + _model_has_pretrained.add(model_name) + if tag: + model_name_tag = '.'.join([model_name, tag]) + _model_pretrained_cfgs[model_name_tag] = pretrained_cfg + if pretrained_cfg.has_weights: + # add model w/ tag if tag is valid + _model_has_pretrained.add(model_name_tag) + _model_with_tags[model_name].append(model_name_tag) + else: + _model_with_tags[model_name].append(model_name) # has empty tag (to slowly remove these instances) + + _model_default_cfgs[model_name] = cfg + + return fn + + +def _natural_key(string_): + return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] + + +def list_models( + filter: Union[str, List[str]] = '', + module: str = '', + pretrained=False, + exclude_filters: str = '', + name_matches_cfg: bool = False, + include_tags: Optional[bool] = None, +): + """ Return list of available model names, sorted alphabetically + + Args: + filter (str) - Wildcard filter string that works with fnmatch + module (str) - Limit model selection to a specific submodule (ie 'vision_transformer') + pretrained (bool) - Include only models with valid pretrained weights if True + exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter + name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases) + include_tags (Optional[boo]) - Include pretrained tags in model names (model.tag). If None, defaults + set to True when pretrained=True else False (default: None) + Example: + model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet' + model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module + """ + if include_tags is None: + # FIXME should this be default behaviour? or default to include_tags=True? 
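A toy sketch of the registration flow above; 'toynet' and its cfg dict are made-up names, and the old-style dict is converted to PretrainedCfg/DefaultCfg by register_model as shown in the code:

import torch.nn as nn
from timm.models._registry import register_model, is_model, list_models

# old-style per-module cfg dict, picked up because the key matches the entrypoint name
default_cfgs = {
    'toynet': {'url': '', 'num_classes': 10},
}

@register_model
def toynet(pretrained=False, **kwargs):
    # the function name becomes the registered model name
    return nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))

assert is_model('toynet')
print(list_models('toy*'))   # ['toynet']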
+ include_tags = pretrained + + if module: + all_models = list(_module_to_models[module]) + else: + all_models = _model_entrypoints.keys() + + if include_tags: + # expand model names to include names w/ pretrained tags + models_with_tags = [] + for m in all_models: + models_with_tags.extend(_model_with_tags[m]) + all_models = models_with_tags + + if filter: + models = [] + include_filters = filter if isinstance(filter, (tuple, list)) else [filter] + for f in include_filters: + include_models = fnmatch.filter(all_models, f) # include these models + if len(include_models): + models = set(models).union(include_models) + else: + models = all_models + + if exclude_filters: + if not isinstance(exclude_filters, (tuple, list)): + exclude_filters = [exclude_filters] + for xf in exclude_filters: + exclude_models = fnmatch.filter(models, xf) # exclude these models + if len(exclude_models): + models = set(models).difference(exclude_models) + + if pretrained: + models = _model_has_pretrained.intersection(models) + + if name_matches_cfg: + models = set(_model_pretrained_cfgs).intersection(models) + + return list(sorted(models, key=_natural_key)) + + +def list_pretrained( + filter: Union[str, List[str]] = '', + exclude_filters: str = '', +): + return list_models( + filter=filter, + pretrained=True, + exclude_filters=exclude_filters, + include_tags=True, + ) + + +def is_model(model_name): + """ Check if a model name exists + """ + arch_name = get_arch_name(model_name) + return arch_name in _model_entrypoints + + +def model_entrypoint(model_name, module_filter: Optional[str] = None): + """Fetch a model entrypoint for specified model name + """ + arch_name = get_arch_name(model_name) + if module_filter and arch_name not in _module_to_models.get(module_filter, {}): + raise RuntimeError(f'Model ({model_name} not found in module {module_filter}.') + return _model_entrypoints[arch_name] + + +def list_modules(): + """ Return list of module names that contain models / model entrypoints + """ + modules = _module_to_models.keys() + return list(sorted(modules)) + + +def is_model_in_modules(model_name, module_names): + """Check if a model exists within a subset of modules + Args: + model_name (str) - name of model to check + module_names (tuple, list, set) - names of modules to search in + """ + arch_name = get_arch_name(model_name) + assert isinstance(module_names, (tuple, list, set)) + return any(arch_name in _module_to_models[n] for n in module_names) + + +def is_model_pretrained(model_name): + return model_name in _model_has_pretrained + + +def get_pretrained_cfg(model_name): + if model_name in _model_pretrained_cfgs: + return deepcopy(_model_pretrained_cfgs[model_name]) + raise RuntimeError(f'No pretrained config exists for model {model_name}.') + + +def get_pretrained_cfg_value(model_name, cfg_key): + """ Get a specific model default_cfg value by key. None if key doesn't exist. 
+ """ + if model_name in _model_pretrained_cfgs: + return getattr(_model_pretrained_cfgs[model_name], cfg_key, None) + raise RuntimeError(f'No pretrained config exist for model {model_name}.') \ No newline at end of file diff --git a/timm/models/beit.py b/timm/models/beit.py index c44256a3..de71f441 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -61,12 +61,14 @@ import torch.nn.functional as F from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD -from .helpers import build_model_with_cfg -from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_ -from .pretrained import generate_default_cfgs -from .registry import register_model +from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_ +from ._builder import build_model_with_cfg +from ._pretrained import generate_default_cfgs +from ._registry import register_model from .vision_transformer import checkpoint_filter_fn +__all__ = ['Beit'] + def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor: num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index 3815fa30..c67144cc 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -13,9 +13,9 @@ Consider all of the models definitions here as experimental WIP and likely to ch Hacked together by / copyright Ross Wightman, 2021. """ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from ._builder import build_model_with_cfg +from ._registry import register_model from .byobnet import ByoBlockCfg, ByoModelCfg, ByobNet, interleave_blocks -from .helpers import build_model_with_cfg -from .registry import register_model __all__ = [] diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index 1e402629..0e5c9c7f 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -26,18 +26,18 @@ Hacked together by / copyright Ross Wightman, 2021. """ import math from dataclasses import dataclass, field, replace -from typing import Tuple, List, Dict, Optional, Union, Any, Callable, Sequence from functools import partial +from typing import Tuple, List, Dict, Optional, Union, Any, Callable, Sequence import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, named_apply, checkpoint_seq -from .layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \ - create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0, EvoNorm2dS0a,\ - EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a, FilterResponseNormAct2d, FilterResponseNormTlu2d -from .registry import register_model +from timm.layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \ + create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0a +from ._builder import build_model_with_cfg +from ._manipulate import named_apply, checkpoint_seq +from ._registry import register_model __all__ = ['ByobNet', 'ByoModelCfg', 'ByoBlockCfg', 'create_byob_stem', 'create_block'] diff --git a/timm/models/cait.py b/timm/models/cait.py index c0892099..15dcd956 100644 --- a/timm/models/cait.py +++ b/timm/models/cait.py @@ -8,17 +8,16 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W """ # Copyright (c) 2015-present, Facebook, Inc. # All rights reserved. 
-from copy import deepcopy from functools import partial import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_ -from .registry import register_model - +from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_ +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model __all__ = ['Cait', 'ClassAttn', 'LayerScaleBlockClassAttn', 'LayerScaleBlock', 'TalkingHeadAttn'] diff --git a/timm/models/coat.py b/timm/models/coat.py index c3071a6c..4ed6d8e8 100644 --- a/timm/models/coat.py +++ b/timm/models/coat.py @@ -7,7 +7,6 @@ Official CoaT code at: https://github.com/mlpc-ucsd/CoaT Modified from timm/models/vision_transformer.py """ -from copy import deepcopy from functools import partial from typing import Tuple, List, Union @@ -16,19 +15,11 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_ -from .registry import register_model -from .layers import _assert - - -__all__ = [ - "coat_tiny", - "coat_mini", - "coat_lite_tiny", - "coat_lite_mini", - "coat_lite_small" -] +from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert +from ._builder import build_model_with_cfg +from ._registry import register_model + +__all__ = ['CoaT'] def _cfg_coat(url='', **kwargs): diff --git a/timm/models/convit.py b/timm/models/convit.py index 26849f6e..d117ccdc 100644 --- a/timm/models/convit.py +++ b/timm/models/convit.py @@ -22,20 +22,20 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py ''' +from functools import partial + import torch import torch.nn as nn -from functools import partial -import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import DropPath, to_2tuple, trunc_normal_, PatchEmbed, Mlp -from .registry import register_model +from timm.layers import DropPath, trunc_normal_, PatchEmbed, Mlp +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_module +from ._registry import register_model from .vision_transformer_hybrid import HybridEmbed -from .fx_features import register_notrace_module -import torch -import torch.nn as nn + +__all__ = ['ConViT'] def _cfg(url='', **kwargs): diff --git a/timm/models/convmixer.py b/timm/models/convmixer.py index e7e2481a..3a8c6cf5 100644 --- a/timm/models/convmixer.py +++ b/timm/models/convmixer.py @@ -5,9 +5,12 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.models.registry import register_model -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import SelectAdaptivePool2d +from timm.layers import SelectAdaptivePool2d +from ._registry import register_model +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq + +__all__ = ['ConvMixer'] def _cfg(url='', **kwargs): diff --git a/timm/models/convnext.py b/timm/models/convnext.py index 36a484b3..eea5782a 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -18,12 +18,12 @@ import torch 
import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import named_apply, build_model_with_cfg, checkpoint_seq -from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, LayerNorm, \ +from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, LayerNorm, \ create_conv2d, get_act_layer, make_divisible, to_ntuple -from .pretrained import generate_default_cfgs -from .registry import register_model - +from ._builder import build_model_with_cfg +from ._manipulate import named_apply, checkpoint_seq +from ._pretrained import generate_default_cfgs +from ._registry import register_model __all__ = ['ConvNeXt'] # model_registry will add each entrypoint fn to this diff --git a/timm/models/crossvit.py b/timm/models/crossvit.py index 764eb3fe..908fcf6d 100644 --- a/timm/models/crossvit.py +++ b/timm/models/crossvit.py @@ -24,21 +24,22 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W Modifed from Timm. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py """ +from functools import partial +from typing import List from typing import Tuple import torch -import torch.nn as nn -import torch.nn.functional as F import torch.hub -from functools import partial -from typing import List +import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_function -from .helpers import build_model_with_cfg -from .layers import DropPath, to_2tuple, trunc_normal_, _assert -from .registry import register_model -from .vision_transformer import Mlp, Block +from timm.layers import DropPath, to_2tuple, trunc_normal_, _assert +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._registry import register_model +from .vision_transformer import Block + +__all__ = ['CrossViT'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py index 2c09e7e3..280f929e 100644 --- a/timm/models/cspnet.py +++ b/timm/models/cspnet.py @@ -12,20 +12,18 @@ Reference impl via darknet cfg files at https://github.com/WongKinYiu/CrossStage Hacked together by / Copyright 2020 Ross Wightman """ -import collections.abc -from dataclasses import dataclass, field, asdict +from dataclasses import dataclass, asdict from functools import partial -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import torch import torch.nn as nn -import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, named_apply, MATCH_PREV_GROUP -from .layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, get_attn, create_act_layer, make_divisible -from .registry import register_model - +from timm.layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, get_attn, create_act_layer, make_divisible +from ._builder import build_model_with_cfg +from ._manipulate import named_apply, MATCH_PREV_GROUP +from ._registry import register_model __all__ = ['CspNet'] # model_registry will add each entrypoint fn to this diff --git a/timm/models/deit.py b/timm/models/deit.py index 3205b024..24fbbe56 100644 --- a/timm/models/deit.py +++ b/timm/models/deit.py @@ -17,9 +17,11 @@ from torch import nn as nn from timm.data import 
IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.vision_transformer import VisionTransformer, trunc_normal_, checkpoint_filter_fn +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model -from .helpers import build_model_with_cfg, checkpoint_seq -from .registry import register_model +__all__ = ['VisionTransformerDistilled'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/densenet.py b/timm/models/densenet.py index 1afdfd7b..e731f7b0 100644 --- a/timm/models/densenet.py +++ b/timm/models/densenet.py @@ -4,7 +4,6 @@ fixed kwargs passthrough and addition of dynamic global avg/max pool. """ import re from collections import OrderedDict -from functools import partial import torch import torch.nn as nn @@ -13,9 +12,10 @@ import torch.utils.checkpoint as cp from torch.jit.annotations import List from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, MATCH_PREV_GROUP -from .layers import BatchNormAct2d, create_norm_act_layer, BlurPool2d, create_classifier -from .registry import register_model +from timm.layers import BatchNormAct2d, create_norm_act_layer, BlurPool2d, create_classifier +from ._builder import build_model_with_cfg +from ._manipulate import MATCH_PREV_GROUP +from ._registry import register_model __all__ = ['DenseNet'] diff --git a/timm/models/dla.py b/timm/models/dla.py index 0ab807c0..204fcb4b 100644 --- a/timm/models/dla.py +++ b/timm/models/dla.py @@ -13,9 +13,9 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import create_classifier -from .registry import register_model +from timm.layers import create_classifier +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['DLA'] diff --git a/timm/models/dpn.py b/timm/models/dpn.py index 95159729..87bd918f 100644 --- a/timm/models/dpn.py +++ b/timm/models/dpn.py @@ -15,9 +15,9 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import BatchNormAct2d, ConvNormAct, create_conv2d, create_classifier -from .registry import register_model +from timm.layers import BatchNormAct2d, ConvNormAct, create_conv2d, create_classifier +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['DPN'] diff --git a/timm/models/edgenext.py b/timm/models/edgenext.py index 422d4f2c..d90471fb 100644 --- a/timm/models/edgenext.py +++ b/timm/models/edgenext.py @@ -8,20 +8,20 @@ Original code and weights from https://github.com/mmaaz60/EdgeNeXt Modifications and additions for timm by / Copyright 2022, Ross Wightman """ import math -import torch from collections import OrderedDict from functools import partial from typing import Tuple -from torch import nn +import torch import torch.nn.functional as F +from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_module -from .layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d -from .helpers import named_apply, build_model_with_cfg, checkpoint_seq -from .registry import register_model - +from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, 
Mlp, SelectAdaptivePool2d, create_conv2d +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_module +from ._manipulate import named_apply, checkpoint_seq +from ._registry import register_model __all__ = ['EdgeNeXt'] # model_registry will add each entrypoint fn to this diff --git a/timm/models/efficientformer.py b/timm/models/efficientformer.py index 4749d93a..4f33f29a 100644 --- a/timm/models/efficientformer.py +++ b/timm/models/efficientformer.py @@ -18,9 +18,11 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import DropPath, trunc_normal_, to_2tuple, Mlp -from .registry import register_model +from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp +from ._builder import build_model_with_cfg +from ._registry import register_model + +__all__ = ['EfficientFormer'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 3c0efc96..a1324ae3 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -42,15 +42,15 @@ import torch import torch.nn as nn import torch.nn.functional as F - from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .efficientnet_blocks import SqueezeExcite -from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights,\ +from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, GroupNormAct +from ._builder import build_model_with_cfg, pretrained_cfg_for_features +from ._efficientnet_blocks import SqueezeExcite +from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \ round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT -from .features import FeatureInfo, FeatureHooks -from .helpers import build_model_with_cfg, pretrained_cfg_for_features, checkpoint_seq -from .layers import create_conv2d, create_classifier, get_norm_act_layer, EvoNorm2dS0, GroupNormAct -from .registry import register_model +from ._features import FeatureInfo, FeatureHooks +from ._manipulate import checkpoint_seq +from ._registry import register_model __all__ = ['EfficientNet', 'EfficientNetFeatures'] diff --git a/timm/models/factory.py b/timm/models/factory.py index 9e06c1aa..0ae83dc0 100644 --- a/timm/models/factory.py +++ b/timm/models/factory.py @@ -1,100 +1,4 @@ -import os -from typing import Any, Dict, Optional, Union -from urllib.parse import urlsplit +from ._factory import * -from .pretrained import PretrainedCfg, split_model_name_tag -from .helpers import load_checkpoint -from .hub import load_model_config_from_hf -from .layers import set_layer_config -from .registry import is_model, model_entrypoint - - -def parse_model_name(model_name): - if model_name.startswith('hf_hub'): - # NOTE for backwards compat, deprecate hf_hub use - model_name = model_name.replace('hf_hub', 'hf-hub') - parsed = urlsplit(model_name) - assert parsed.scheme in ('', 'timm', 'hf-hub') - if parsed.scheme == 'hf-hub': - # FIXME may use fragment as revision, currently `@` in URI path - return parsed.scheme, parsed.path - else: - model_name = os.path.split(parsed.path)[-1] - return 'timm', model_name - - -def safe_model_name(model_name, remove_source=True): - # return a filename / path safe model name - def make_safe(name): - return ''.join(c if c.isalnum() else '_' for c 
in name).rstrip('_') - if remove_source: - model_name = parse_model_name(model_name)[-1] - return make_safe(model_name) - - -def create_model( - model_name: str, - pretrained: bool = False, - pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None, - pretrained_cfg_overlay: Optional[Dict[str, Any]] = None, - checkpoint_path: str = '', - scriptable: Optional[bool] = None, - exportable: Optional[bool] = None, - no_jit: Optional[bool] = None, - **kwargs, -): - """Create a model - - Lookup model's entrypoint function and pass relevant args to create a new model. - - **kwargs will be passed through entrypoint fn to timm.models.build_model_with_cfg() - and then the model class __init__(). kwargs values set to None are pruned before passing. - - Args: - model_name (str): name of model to instantiate - pretrained (bool): load pretrained ImageNet-1k weights if true - pretrained_cfg (Union[str, dict, PretrainedCfg]): pass in external pretrained_cfg for model - pretrained_cfg_overlay (dict): replace key-values in base pretrained_cfg with these - checkpoint_path (str): path of checkpoint to load _after_ the model is initialized - scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet) - exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet) - no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only) - - Keyword Args: - drop_rate (float): dropout rate for training (default: 0.0) - global_pool (str): global pool type (default: 'avg') - **: other kwargs are consumed by builder or model __init__() - """ - # Parameters that aren't supported by all models or are intended to only override model defaults if set - # should default to None in command line args/cfg. Remove them if they are present and not set so that - # non-supporting models don't break and default args remain in effect. - kwargs = {k: v for k, v in kwargs.items() if v is not None} - - model_source, model_name = parse_model_name(model_name) - if model_source == 'hf-hub': - assert not pretrained_cfg, 'pretrained_cfg should not be set when sourcing model from Hugging Face Hub.' - # For model names specified in the form `hf-hub:path/architecture_name@revision`, - # load model weights + pretrained_cfg from Hugging Face hub. 
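A usage sketch for the create_model behavior documented above; the hub repo id is a placeholder, and the hf-hub form requires huggingface_hub plus network access:

import timm

# plain architecture name; a pretrained tag may follow a '.' (e.g. 'model.tag')
m1 = timm.create_model('resnet50', pretrained=False)

# Hugging Face Hub source; pretrained_cfg and weights come from the hub repo
m2 = timm.create_model('hf-hub:org/some-model', pretrained=True)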
- pretrained_cfg, model_name = load_model_config_from_hf(model_name) - else: - model_name, pretrained_tag = split_model_name_tag(model_name) - if not pretrained_cfg: - # a valid pretrained_cfg argument takes priority over tag in model name - pretrained_cfg = pretrained_tag - - if not is_model(model_name): - raise RuntimeError('Unknown model (%s)' % model_name) - - create_fn = model_entrypoint(model_name) - with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit): - model = create_fn( - pretrained=pretrained, - pretrained_cfg=pretrained_cfg, - pretrained_cfg_overlay=pretrained_cfg_overlay, - **kwargs, - ) - - if checkpoint_path: - load_checkpoint(model, checkpoint_path) - - return model +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning) diff --git a/timm/models/features.py b/timm/models/features.py index 0bc46419..25605d99 100644 --- a/timm/models/features.py +++ b/timm/models/features.py @@ -1,284 +1,4 @@ -""" PyTorch Feature Extraction Helpers +from ._features import * -A collection of classes, functions, modules to help extract features from models -and provide a common interface for describing them. - -The return_layers, module re-writing idea inspired by torchvision IntermediateLayerGetter -https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py - -Hacked together by / Copyright 2020 Ross Wightman -""" -from collections import OrderedDict, defaultdict -from copy import deepcopy -from functools import partial -from typing import Dict, List, Tuple - -import torch -import torch.nn as nn - - -class FeatureInfo: - - def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]): - prev_reduction = 1 - for fi in feature_info: - # sanity check the mandatory fields, there may be additional fields depending on the model - assert 'num_chs' in fi and fi['num_chs'] > 0 - assert 'reduction' in fi and fi['reduction'] >= prev_reduction - prev_reduction = fi['reduction'] - assert 'module' in fi - self.out_indices = out_indices - self.info = feature_info - - def from_other(self, out_indices: Tuple[int]): - return FeatureInfo(deepcopy(self.info), out_indices) - - def get(self, key, idx=None): - """ Get value by key at specified index (indices) - if idx == None, returns value for key at each output index - if idx is an integer, return value for that feature module index (ignoring output indices) - if idx is a list/tupple, return value for each module index (ignoring output indices) - """ - if idx is None: - return [self.info[i][key] for i in self.out_indices] - if isinstance(idx, (tuple, list)): - return [self.info[i][key] for i in idx] - else: - return self.info[idx][key] - - def get_dicts(self, keys=None, idx=None): - """ return info dicts for specified keys (or all if None) at specified indices (or out_indices if None) - """ - if idx is None: - if keys is None: - return [self.info[i] for i in self.out_indices] - else: - return [{k: self.info[i][k] for k in keys} for i in self.out_indices] - if isinstance(idx, (tuple, list)): - return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx] - else: - return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys} - - def channels(self, idx=None): - """ feature channels accessor - """ - return self.get('num_chs', idx) - - def reduction(self, idx=None): - """ feature reduction (output stride) accessor - """ - return self.get('reduction', idx) - - def 
module_name(self, idx=None): - """ feature module name accessor - """ - return self.get('module', idx) - - def __getitem__(self, item): - return self.info[item] - - def __len__(self): - return len(self.info) - - -class FeatureHooks: - """ Feature Hook Helper - - This module helps with the setup and extraction of hooks for extracting features from - internal nodes in a model by node name. This works quite well in eager Python but needs - redesign for torchscript. - """ - - def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'): - # setup feature hooks - modules = {k: v for k, v in named_modules} - for i, h in enumerate(hooks): - hook_name = h['module'] - m = modules[hook_name] - hook_id = out_map[i] if out_map else hook_name - hook_fn = partial(self._collect_output_hook, hook_id) - hook_type = h.get('hook_type', default_hook_type) - if hook_type == 'forward_pre': - m.register_forward_pre_hook(hook_fn) - elif hook_type == 'forward': - m.register_forward_hook(hook_fn) - else: - assert False, "Unsupported hook type" - self._feature_outputs = defaultdict(OrderedDict) - - def _collect_output_hook(self, hook_id, *args): - x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre - if isinstance(x, tuple): - x = x[0] # unwrap input tuple - self._feature_outputs[x.device][hook_id] = x - - def get_output(self, device) -> Dict[str, torch.tensor]: - output = self._feature_outputs[device] - self._feature_outputs[device] = OrderedDict() # clear after reading - return output - - -def _module_list(module, flatten_sequential=False): - # a yield/iter would be better for this but wouldn't be compatible with torchscript - ml = [] - for name, module in module.named_children(): - if flatten_sequential and isinstance(module, nn.Sequential): - # first level of Sequential containers is flattened into containing model - for child_name, child_module in module.named_children(): - combined = [name, child_name] - ml.append(('_'.join(combined), '.'.join(combined), child_module)) - else: - ml.append((name, name, module)) - return ml - - -def _get_feature_info(net, out_indices): - feature_info = getattr(net, 'feature_info') - if isinstance(feature_info, FeatureInfo): - return feature_info.from_other(out_indices) - elif isinstance(feature_info, (list, tuple)): - return FeatureInfo(net.feature_info, out_indices) - else: - assert False, "Provided feature_info is not valid" - - -def _get_return_layers(feature_info, out_map): - module_names = feature_info.module_name() - return_layers = {} - for i, name in enumerate(module_names): - return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i] - return return_layers - - -class FeatureDictNet(nn.ModuleDict): - """ Feature extractor with OrderedDict return - - Wrap a model and extract features as specified by the out indices, the network is - partially re-built from contained modules. - - There is a strong assumption that the modules have been registered into the model in the same - order as they are used. There should be no reuse of the same nn.Module more than once, including - trivial modules like `self.relu = nn.ReLU`. - - Only submodules that are directly assigned to the model class (`model.feature1`) or at most - one Sequential container deep (`model.features.1`, with flatten_sequent=True) can be captured. 
- All Sequential containers that are directly assigned to the original model will have their - modules assigned to this module with the name `model.features.1` being changed to `model.features_1` - - Arguments: - model (nn.Module): model from which we will extract the features - out_indices (tuple[int]): model output indices to extract features for - out_map (sequence): list or tuple specifying desired return id for each out index, - otherwise str(index) is used - feature_concat (bool): whether to concatenate intermediate features that are lists or tuples - vs select element [0] - flatten_sequential (bool): whether to flatten sequential modules assigned to model - """ - def __init__( - self, model, - out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False): - super(FeatureDictNet, self).__init__() - self.feature_info = _get_feature_info(model, out_indices) - self.concat = feature_concat - self.return_layers = {} - return_layers = _get_return_layers(self.feature_info, out_map) - modules = _module_list(model, flatten_sequential=flatten_sequential) - remaining = set(return_layers.keys()) - layers = OrderedDict() - for new_name, old_name, module in modules: - layers[new_name] = module - if old_name in remaining: - # return id has to be consistently str type for torchscript - self.return_layers[new_name] = str(return_layers[old_name]) - remaining.remove(old_name) - if not remaining: - break - assert not remaining and len(self.return_layers) == len(return_layers), \ - f'Return layers ({remaining}) are not present in model' - self.update(layers) - - def _collect(self, x) -> (Dict[str, torch.Tensor]): - out = OrderedDict() - for name, module in self.items(): - x = module(x) - if name in self.return_layers: - out_id = self.return_layers[name] - if isinstance(x, (tuple, list)): - # If model tap is a tuple or list, concat or select first element - # FIXME this may need to be more generic / flexible for some nets - out[out_id] = torch.cat(x, 1) if self.concat else x[0] - else: - out[out_id] = x - return out - - def forward(self, x) -> Dict[str, torch.Tensor]: - return self._collect(x) - - -class FeatureListNet(FeatureDictNet): - """ Feature extractor with list return - - See docstring for FeatureDictNet above, this class exists only to appease Torchscript typing constraints. - In eager Python we could have returned List[Tensor] vs Dict[id, Tensor] based on a member bool. - """ - def __init__( - self, model, - out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False): - super(FeatureListNet, self).__init__( - model, out_indices=out_indices, out_map=out_map, feature_concat=feature_concat, - flatten_sequential=flatten_sequential) - - def forward(self, x) -> (List[torch.Tensor]): - return list(self._collect(x).values()) - - -class FeatureHookNet(nn.ModuleDict): - """ FeatureHookNet - - Wrap a model and extract features specified by the out indices using forward/forward-pre hooks. - - If `no_rewrite` is True, features are extracted via hooks without modifying the underlying - network in any way. - - If `no_rewrite` is False, the model will be re-written as in the - FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one. 
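For context, a minimal sketch of how these wrapper classes are usually reached in practice, via the `features_only` path of `create_model` (the `resnet50` name and input size are illustrative assumptions; any architecture that registers `feature_info` metadata behaves the same way):

import torch
import timm

# Build a backbone that returns intermediate feature maps instead of logits.
# Under the hood this wraps the model in FeatureListNet (or FeatureHookNet for hook-based models).
model = timm.create_model('resnet50', features_only=True, out_indices=(1, 2, 3, 4), pretrained=False)
print(model.feature_info.channels())   # per-index channel counts, via FeatureInfo.channels()
print(model.feature_info.reduction())  # per-index output strides, via FeatureInfo.reduction()

features = model(torch.randn(1, 3, 224, 224))
for f in features:
    print(f.shape)                     # one tensor per entry in out_indices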
- - FIXME this does not currently work with Torchscript, see FeatureHooks class - """ - def __init__( - self, model, - out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, no_rewrite=False, - feature_concat=False, flatten_sequential=False, default_hook_type='forward'): - super(FeatureHookNet, self).__init__() - assert not torch.jit.is_scripting() - self.feature_info = _get_feature_info(model, out_indices) - self.out_as_dict = out_as_dict - layers = OrderedDict() - hooks = [] - if no_rewrite: - assert not flatten_sequential - if hasattr(model, 'reset_classifier'): # make sure classifier is removed? - model.reset_classifier(0) - layers['body'] = model - hooks.extend(self.feature_info.get_dicts()) - else: - modules = _module_list(model, flatten_sequential=flatten_sequential) - remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type - for f in self.feature_info.get_dicts()} - for new_name, old_name, module in modules: - layers[new_name] = module - for fn, fm in module.named_modules(prefix=old_name): - if fn in remaining: - hooks.append(dict(module=fn, hook_type=remaining[fn])) - del remaining[fn] - if not remaining: - break - assert not remaining, f'Return layers ({remaining}) are not present in model' - self.update(layers) - self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map) - - def forward(self, x): - for name, module in self.items(): - x = module(x) - out = self.hooks.get_output(x.device) - return out if self.out_as_dict else list(out.values()) +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning) diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py index b09381b7..0ff3a18b 100644 --- a/timm/models/fx_features.py +++ b/timm/models/fx_features.py @@ -1,106 +1,4 @@ -""" PyTorch FX Based Feature Extraction Helpers -Using https://pytorch.org/vision/stable/feature_extraction.html -""" -from typing import Callable, List, Dict, Union, Type +from ._features_fx import * -import torch -from torch import nn - -from .features import _get_feature_info - -try: - from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor - has_fx_feature_extraction = True -except ImportError: - has_fx_feature_extraction = False - -# Layers we went to treat as leaf modules -from .layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame -from .layers.non_local_attn import BilinearAttnTransform -from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame - -# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here -# BUT modules from timm.models should use the registration mechanism below -_leaf_modules = { - BilinearAttnTransform, # reason: flow control t <= 1 - # Reason: get_same_padding has a max which raises a control flow error - Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame, - CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0]) -} - -try: - from .layers import InplaceAbn - _leaf_modules.add(InplaceAbn) -except ImportError: - pass - - -def register_notrace_module(module: Type[nn.Module]): - """ - Any module not under timm.models.layers should get this decorator if we don't want to trace through it. 
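As an illustration of that contract, a hedged sketch of marking a custom block as an FX leaf (the block itself is hypothetical; only the decorator and its new `_features_fx` home come from this diff):

import torch
import torch.nn as nn
from timm.models._features_fx import register_notrace_module  # the old timm.models.fx_features path still resolves via the shim, with a warning

@register_notrace_module  # keep torchvision's symbolic tracer from descending into this module
class GatedBlock(nn.Module):
    def forward(self, x):
        # data-dependent control flow like this cannot be symbolically traced
        return x + 1 if x.mean() > 0 else x - 1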
- """ - _leaf_modules.add(module) - return module - - -# Functions we want to autowrap (treat them as leaves) -_autowrap_functions = set() - - -def register_notrace_function(func: Callable): - """ - Decorator for functions which ought not to be traced through - """ - _autowrap_functions.add(func) - return func - - -def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]): - assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction' - return _create_feature_extractor( - model, return_nodes, - tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)} - ) - - -class FeatureGraphNet(nn.Module): - """ A FX Graph based feature extractor that works with the model feature_info metadata - """ - def __init__(self, model, out_indices, out_map=None): - super().__init__() - assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction' - self.feature_info = _get_feature_info(model, out_indices) - if out_map is not None: - assert len(out_map) == len(out_indices) - return_nodes = { - info['module']: out_map[i] if out_map is not None else info['module'] - for i, info in enumerate(self.feature_info) if i in out_indices} - self.graph_module = create_feature_extractor(model, return_nodes) - - def forward(self, x): - return list(self.graph_module(x).values()) - - -class GraphExtractNet(nn.Module): - """ A standalone feature extraction wrapper that maps dict -> list or single tensor - NOTE: - * one can use feature_extractor directly if dictionary output is desired - * unlike FeatureGraphNet, this is intended to be used standalone and not with model feature_info - metadata for builtin feature extraction mode - * create_feature_extractor can be used directly if dictionary output is acceptable - - Args: - model: model to extract features from - return_nodes: node names to return features from (dict or list) - squeeze_out: if only one output, and output in list format, flatten to single tensor - """ - def __init__(self, model, return_nodes: Union[Dict[str, str], List[str]], squeeze_out: bool = True): - super().__init__() - self.squeeze_out = squeeze_out - self.graph_module = create_feature_extractor(model, return_nodes) - - def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]: - out = list(self.graph_module(x).values()) - if self.squeeze_out and len(out) == 1: - return out[0] - return out +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning) diff --git a/timm/models/gcvit.py b/timm/models/gcvit.py index fb375e2c..ec9b7e5e 100644 --- a/timm/models/gcvit.py +++ b/timm/models/gcvit.py @@ -28,12 +28,13 @@ import torch.nn as nn import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_function -from .helpers import build_model_with_cfg, named_apply -from .layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d,\ +from timm.layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d, \ get_attn, get_act_layer, get_norm_layer, _assert -from .registry import register_model -from .vision_transformer_relpos import RelPosMlp, RelPosBias # FIXME move to common location +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._manipulate import named_apply +from ._registry import 
register_model +from .vision_transformer_relpos import RelPosBias # FIXME move to common location __all__ = ['GlobalContextVit'] diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py index e19af88b..492049b9 100644 --- a/timm/models/ghostnet.py +++ b/timm/models/ghostnet.py @@ -11,13 +11,12 @@ import torch import torch.nn as nn import torch.nn.functional as F - from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .layers import SelectAdaptivePool2d, Linear, make_divisible -from .efficientnet_blocks import SqueezeExcite, ConvBnAct -from .helpers import build_model_with_cfg, checkpoint_seq -from .registry import register_model - +from timm.layers import SelectAdaptivePool2d, Linear, make_divisible +from ._builder import build_model_with_cfg +from ._efficientnet_blocks import SqueezeExcite, ConvBnAct +from ._manipulate import checkpoint_seq +from ._registry import register_model __all__ = ['GhostNet'] diff --git a/timm/models/gluon_resnet.py b/timm/models/gluon_resnet.py index a1e73554..2b4131fb 100644 --- a/timm/models/gluon_resnet.py +++ b/timm/models/gluon_resnet.py @@ -5,11 +5,13 @@ by Ross Wightman """ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import SEModule -from .registry import register_model +from timm.layers import SEModule +from ._builder import build_model_with_cfg +from ._registry import register_model from .resnet import ResNet, Bottleneck, BasicBlock +__all__ = [] + def _cfg(url='', **kwargs): return { diff --git a/timm/models/gluon_xception.py b/timm/models/gluon_xception.py index a9c946b2..b487d0fd 100644 --- a/timm/models/gluon_xception.py +++ b/timm/models/gluon_xception.py @@ -13,9 +13,9 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import create_classifier, get_padding -from .registry import register_model +from timm.layers import create_classifier, get_padding +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['Xception65'] diff --git a/timm/models/hardcorenas.py b/timm/models/hardcorenas.py index 132eeab4..d77e642a 100644 --- a/timm/models/hardcorenas.py +++ b/timm/models/hardcorenas.py @@ -3,12 +3,14 @@ from functools import partial import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .efficientnet_blocks import SqueezeExcite -from .efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args, round_channels -from .helpers import build_model_with_cfg, pretrained_cfg_for_features -from .layers import get_act_fn +from ._builder import build_model_with_cfg +from ._builder import pretrained_cfg_for_features +from ._efficientnet_blocks import SqueezeExcite +from ._efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args, round_channels +from ._registry import register_model from .mobilenetv3 import MobileNetV3, MobileNetV3Features -from .registry import register_model + +__all__ = [] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/helpers.py b/timm/models/helpers.py index 2a5551e0..6bc82eb8 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -1,855 +1,7 @@ -""" Model creation / weight loading / state_dict helpers +from ._builder import * +from ._helpers import * +from ._manipulate import * +from ._prune import * -Hacked together by / 
Copyright 2020 Ross Wightman -""" -import collections.abc -import dataclasses -import logging -import math -import os -import re -from collections import OrderedDict, defaultdict -from copy import deepcopy -from itertools import chain -from typing import Any, Callable, Optional, Tuple, Dict, Union - -import torch -import torch.nn as nn -from torch.hub import load_state_dict_from_url -from torch.utils.checkpoint import checkpoint - -from .pretrained import PretrainedCfg -from .features import FeatureListNet, FeatureDictNet, FeatureHookNet -from .fx_features import FeatureGraphNet -from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf -from .layers import Conv2dSame, Linear, BatchNormAct2d -from .registry import get_pretrained_cfg - - -_logger = logging.getLogger(__name__) - - -# Global variables for rarely used pretrained checkpoint download progress and hash check. -# Use set_pretrained_download_progress / set_pretrained_check_hash functions to toggle. -_DOWNLOAD_PROGRESS = False -_CHECK_HASH = False - - -def clean_state_dict(state_dict): - # 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training - cleaned_state_dict = OrderedDict() - for k, v in state_dict.items(): - name = k[7:] if k.startswith('module.') else k - cleaned_state_dict[name] = v - return cleaned_state_dict - - -def load_state_dict(checkpoint_path, use_ema=True): - if checkpoint_path and os.path.isfile(checkpoint_path): - checkpoint = torch.load(checkpoint_path, map_location='cpu') - state_dict_key = '' - if isinstance(checkpoint, dict): - if use_ema and checkpoint.get('state_dict_ema', None) is not None: - state_dict_key = 'state_dict_ema' - elif use_ema and checkpoint.get('model_ema', None) is not None: - state_dict_key = 'model_ema' - elif 'state_dict' in checkpoint: - state_dict_key = 'state_dict' - elif 'model' in checkpoint: - state_dict_key = 'model' - state_dict = clean_state_dict(checkpoint[state_dict_key] if state_dict_key else checkpoint) - _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path)) - return state_dict - else: - _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) - raise FileNotFoundError() - - -def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True, remap=False): - if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'): - # numpy checkpoint, try to load via model specific load_pretrained fn - if hasattr(model, 'load_pretrained'): - model.load_pretrained(checkpoint_path) - else: - raise NotImplementedError('Model cannot load numpy checkpoint') - return - state_dict = load_state_dict(checkpoint_path, use_ema) - if remap: - state_dict = remap_checkpoint(model, state_dict) - incompatible_keys = model.load_state_dict(state_dict, strict=strict) - return incompatible_keys - - -def remap_checkpoint(model, state_dict, allow_reshape=True): - """ remap checkpoint by iterating over state dicts in order (ignoring original keys). - This assumes models (and originating state dict) were created with params registered in same order. - """ - out_dict = {} - for (ka, va), (kb, vb) in zip(model.state_dict().items(), state_dict.items()): - assert va.numel == vb.numel, f'Tensor size mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.' - if va.shape != vb.shape: - if allow_reshape: - vb = vb.reshape(va.shape) - else: - assert False, f'Tensor shape mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.' 
- out_dict[ka] = vb - return out_dict - - -def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True): - resume_epoch = None - if os.path.isfile(checkpoint_path): - checkpoint = torch.load(checkpoint_path, map_location='cpu') - if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: - if log_info: - _logger.info('Restoring model state from checkpoint...') - state_dict = clean_state_dict(checkpoint['state_dict']) - model.load_state_dict(state_dict) - - if optimizer is not None and 'optimizer' in checkpoint: - if log_info: - _logger.info('Restoring optimizer state from checkpoint...') - optimizer.load_state_dict(checkpoint['optimizer']) - - if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint: - if log_info: - _logger.info('Restoring AMP loss scaler state from checkpoint...') - loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key]) - - if 'epoch' in checkpoint: - resume_epoch = checkpoint['epoch'] - if 'version' in checkpoint and checkpoint['version'] > 1: - resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save - - if log_info: - _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) - else: - model.load_state_dict(checkpoint) - if log_info: - _logger.info("Loaded checkpoint '{}'".format(checkpoint_path)) - return resume_epoch - else: - _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) - raise FileNotFoundError() - - -def _resolve_pretrained_source(pretrained_cfg): - cfg_source = pretrained_cfg.get('source', '') - pretrained_url = pretrained_cfg.get('url', None) - pretrained_file = pretrained_cfg.get('file', None) - hf_hub_id = pretrained_cfg.get('hf_hub_id', None) - # resolve where to load pretrained weights from - load_from = '' - pretrained_loc = '' - if cfg_source == 'hf-hub' and has_hf_hub(necessary=True): - # hf-hub specified as source via model identifier - load_from = 'hf-hub' - assert hf_hub_id - pretrained_loc = hf_hub_id - else: - # default source == timm or unspecified - if pretrained_file: - load_from = 'file' - pretrained_loc = pretrained_file - elif pretrained_url: - load_from = 'url' - pretrained_loc = pretrained_url - elif hf_hub_id and has_hf_hub(necessary=True): - # hf-hub available as alternate weight source in default_cfg - load_from = 'hf-hub' - pretrained_loc = hf_hub_id - if load_from == 'hf-hub' and pretrained_cfg.get('hf_hub_filename', None): - # if a filename override is set, return tuple for location w/ (hub_id, filename) - pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename'] - return load_from, pretrained_loc - - -def set_pretrained_download_progress(enable=True): - """ Set download progress for pretrained weights on/off (globally). """ - global _DOWNLOAD_PROGRESS - _DOWNLOAD_PROGRESS = enable - - -def set_pretrained_check_hash(enable=True): - """ Set hash checking for pretrained weights on/off (globally). """ - global _CHECK_HASH - _CHECK_HASH = enable - - -def load_custom_pretrained( - model: nn.Module, - pretrained_cfg: Optional[Dict] = None, - load_fn: Optional[Callable] = None, -): - r"""Loads a custom (read non .pth) weight file - - Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls - a passed in custom load fun, or the `load_pretrained` model member fn. - - If the object is already present in `model_dir`, it's deserialized and returned. 
- The default value of `model_dir` is ``/checkpoints`` where - `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`. - - Args: - model: The instantiated model to load weights into - pretrained_cfg (dict): Default pretrained model cfg - load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named - 'laod_pretrained' on the model will be called if it exists - """ - pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None) - if not pretrained_cfg: - _logger.warning("Invalid pretrained config, cannot load weights.") - return - - load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg) - if not load_from: - _logger.warning("No pretrained weights exist for this model. Using random initialization.") - return - if load_from == 'hf-hub': # FIXME - _logger.warning("Hugging Face hub not currently supported for custom load pretrained models.") - elif load_from == 'url': - pretrained_loc = download_cached_file( - pretrained_loc, - check_hash=_CHECK_HASH, - progress=_DOWNLOAD_PROGRESS - ) - - if load_fn is not None: - load_fn(model, pretrained_loc) - elif hasattr(model, 'load_pretrained'): - model.load_pretrained(pretrained_loc) - else: - _logger.warning("Valid function to load pretrained weights is not available, using random initialization.") - - -def adapt_input_conv(in_chans, conv_weight): - conv_type = conv_weight.dtype - conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU - O, I, J, K = conv_weight.shape - if in_chans == 1: - if I > 3: - assert conv_weight.shape[1] % 3 == 0 - # For models with space2depth stems - conv_weight = conv_weight.reshape(O, I // 3, 3, J, K) - conv_weight = conv_weight.sum(dim=2, keepdim=False) - else: - conv_weight = conv_weight.sum(dim=1, keepdim=True) - elif in_chans != 3: - if I != 3: - raise NotImplementedError('Weight format not supported by conversion.') - else: - # NOTE this strategy should be better than random init, but there could be other combinations of - # the original RGB input layer weights that'd work better for specific cases. 
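To make that note concrete, a small standalone sketch of the repeat-and-rescale step that follows (shapes and channel count are illustrative only):

import math
import torch

conv_weight = torch.randn(64, 3, 7, 7)       # pretrained RGB stem weights, layout (out, in, kH, kW)
in_chans = 5                                  # target input channels for the new model
repeat = int(math.ceil(in_chans / 3))         # 2 copies of the RGB filters cover 5 channels
adapted = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans]   # (64, 5, 7, 7)
adapted *= 3 / float(in_chans)                # rescale so the summed response stays comparable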
- repeat = int(math.ceil(in_chans / 3)) - conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] - conv_weight *= (3 / float(in_chans)) - conv_weight = conv_weight.to(conv_type) - return conv_weight - - -def load_pretrained( - model: nn.Module, - pretrained_cfg: Optional[Dict] = None, - num_classes: int = 1000, - in_chans: int = 3, - filter_fn: Optional[Callable] = None, - strict: bool = True, -): - """ Load pretrained checkpoint - - Args: - model (nn.Module) : PyTorch model module - pretrained_cfg (Optional[Dict]): configuration for pretrained weights / target dataset - num_classes (int): num_classes for target model - in_chans (int): in_chans for target model - filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args) - strict (bool): strict load of checkpoint - - """ - pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None) - if not pretrained_cfg: - _logger.warning("Invalid pretrained config, cannot load weights.") - return - - load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg) - if load_from == 'file': - _logger.info(f'Loading pretrained weights from file ({pretrained_loc})') - state_dict = load_state_dict(pretrained_loc) - elif load_from == 'url': - _logger.info(f'Loading pretrained weights from url ({pretrained_loc})') - state_dict = load_state_dict_from_url( - pretrained_loc, - map_location='cpu', - progress=_DOWNLOAD_PROGRESS, - check_hash=_CHECK_HASH, - ) - elif load_from == 'hf-hub': - _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})') - if isinstance(pretrained_loc, (list, tuple)): - state_dict = load_state_dict_from_hf(*pretrained_loc) - else: - state_dict = load_state_dict_from_hf(pretrained_loc) - else: - _logger.warning("No pretrained weights exist or were found for this model. 
Using random initialization.") - return - - if filter_fn is not None: - # for backwards compat with filter fn that take one arg, try one first, the two - try: - state_dict = filter_fn(state_dict) - except TypeError: - state_dict = filter_fn(state_dict, model) - - input_convs = pretrained_cfg.get('first_conv', None) - if input_convs is not None and in_chans != 3: - if isinstance(input_convs, str): - input_convs = (input_convs,) - for input_conv_name in input_convs: - weight_name = input_conv_name + '.weight' - try: - state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name]) - _logger.info( - f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)') - except NotImplementedError as e: - del state_dict[weight_name] - strict = False - _logger.warning( - f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.') - - classifiers = pretrained_cfg.get('classifier', None) - label_offset = pretrained_cfg.get('label_offset', 0) - if classifiers is not None: - if isinstance(classifiers, str): - classifiers = (classifiers,) - if num_classes != pretrained_cfg['num_classes']: - for classifier_name in classifiers: - # completely discard fully connected if model num_classes doesn't match pretrained weights - state_dict.pop(classifier_name + '.weight', None) - state_dict.pop(classifier_name + '.bias', None) - strict = False - elif label_offset > 0: - for classifier_name in classifiers: - # special case for pretrained weights with an extra background class in pretrained weights - classifier_weight = state_dict[classifier_name + '.weight'] - state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:] - classifier_bias = state_dict[classifier_name + '.bias'] - state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:] - - model.load_state_dict(state_dict, strict=strict) - - -def extract_layer(model, layer): - layer = layer.split('.') - module = model - if hasattr(model, 'module') and layer[0] != 'module': - module = model.module - if not hasattr(model, 'module') and layer[0] == 'module': - layer = layer[1:] - for l in layer: - if hasattr(module, l): - if not l.isdigit(): - module = getattr(module, l) - else: - module = module[int(l)] - else: - return module - return module - - -def set_layer(model, layer, val): - layer = layer.split('.') - module = model - if hasattr(model, 'module') and layer[0] != 'module': - module = model.module - lst_index = 0 - module2 = module - for l in layer: - if hasattr(module2, l): - if not l.isdigit(): - module2 = getattr(module2, l) - else: - module2 = module2[int(l)] - lst_index += 1 - lst_index -= 1 - for l in layer[:lst_index]: - if not l.isdigit(): - module = getattr(module, l) - else: - module = module[int(l)] - l = layer[lst_index] - setattr(module, l, val) - - -def adapt_model_from_string(parent_module, model_string): - separator = '***' - state_dict = {} - lst_shape = model_string.split(separator) - for k in lst_shape: - k = k.split(':') - key = k[0] - shape = k[1][1:-1].split(',') - if shape[0] != '': - state_dict[key] = [int(i) for i in shape] - - new_module = deepcopy(parent_module) - for n, m in parent_module.named_modules(): - old_module = extract_layer(parent_module, n) - if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame): - if isinstance(old_module, Conv2dSame): - conv = Conv2dSame - else: - conv = nn.Conv2d - s = state_dict[n + '.weight'] - in_channels = s[1] - out_channels = s[0] - g = 1 - if old_module.groups > 1: - 
in_channels = out_channels - g = in_channels - new_conv = conv( - in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size, - bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation, - groups=g, stride=old_module.stride) - set_layer(new_module, n, new_conv) - elif isinstance(old_module, BatchNormAct2d): - new_bn = BatchNormAct2d( - state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum, - affine=old_module.affine, track_running_stats=True) - new_bn.drop = old_module.drop - new_bn.act = old_module.act - set_layer(new_module, n, new_bn) - elif isinstance(old_module, nn.BatchNorm2d): - new_bn = nn.BatchNorm2d( - num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum, - affine=old_module.affine, track_running_stats=True) - set_layer(new_module, n, new_bn) - elif isinstance(old_module, nn.Linear): - # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer? - num_features = state_dict[n + '.weight'][1] - new_fc = Linear( - in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None) - set_layer(new_module, n, new_fc) - if hasattr(new_module, 'num_features'): - new_module.num_features = num_features - new_module.eval() - parent_module.eval() - - return new_module - - -def adapt_model_from_file(parent_module, model_variant): - adapt_file = os.path.join(os.path.dirname(__file__), 'pruned', model_variant + '.txt') - with open(adapt_file, 'r') as f: - return adapt_model_from_string(parent_module, f.read().strip()) - - -def pretrained_cfg_for_features(pretrained_cfg): - pretrained_cfg = deepcopy(pretrained_cfg) - # remove default pretrained cfg fields that don't have much relevance for feature backbone - to_remove = ('num_classes', 'classifier', 'global_pool') # add default final pool size? 
- for tr in to_remove: - pretrained_cfg.pop(tr, None) - return pretrained_cfg - - -def _filter_kwargs(kwargs, names): - if not kwargs or not names: - return - for n in names: - kwargs.pop(n, None) - - -def _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter): - """ Update the default_cfg and kwargs before passing to model - - Args: - pretrained_cfg: input pretrained cfg (updated in-place) - kwargs: keyword args passed to model build fn (updated in-place) - kwargs_filter: keyword arg keys that must be removed before model __init__ - """ - # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs) - default_kwarg_names = ('num_classes', 'global_pool', 'in_chans') - if pretrained_cfg.get('fixed_input_size', False): - # if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size - default_kwarg_names += ('img_size',) - - for n in default_kwarg_names: - # for legacy reasons, model __init__args uses img_size + in_chans as separate args while - # pretrained_cfg has one input_size=(C, H ,W) entry - if n == 'img_size': - input_size = pretrained_cfg.get('input_size', None) - if input_size is not None: - assert len(input_size) == 3 - kwargs.setdefault(n, input_size[-2:]) - elif n == 'in_chans': - input_size = pretrained_cfg.get('input_size', None) - if input_size is not None: - assert len(input_size) == 3 - kwargs.setdefault(n, input_size[0]) - else: - default_val = pretrained_cfg.get(n, None) - if default_val is not None: - kwargs.setdefault(n, pretrained_cfg[n]) - - # Filter keyword args for task specific model variants (some 'features only' models, etc.) - _filter_kwargs(kwargs, names=kwargs_filter) - - -def resolve_pretrained_cfg( - variant: str, - pretrained_cfg=None, - pretrained_cfg_overlay=None, -) -> PretrainedCfg: - model_with_tag = variant - pretrained_tag = None - if pretrained_cfg: - if isinstance(pretrained_cfg, dict): - # pretrained_cfg dict passed as arg, validate by converting to PretrainedCfg - pretrained_cfg = PretrainedCfg(**pretrained_cfg) - elif isinstance(pretrained_cfg, str): - pretrained_tag = pretrained_cfg - pretrained_cfg = None - - # fallback to looking up pretrained cfg in model registry by variant identifier - if not pretrained_cfg: - if pretrained_tag: - model_with_tag = '.'.join([variant, pretrained_tag]) - pretrained_cfg = get_pretrained_cfg(model_with_tag) - - if not pretrained_cfg: - _logger.warning( - f"No pretrained configuration specified for {model_with_tag} model. Using a default." 
- f" Please add a config to the model pretrained_cfg registry or pass explicitly.") - pretrained_cfg = PretrainedCfg() # instance with defaults - - pretrained_cfg_overlay = pretrained_cfg_overlay or {} - if not pretrained_cfg.architecture: - pretrained_cfg_overlay.setdefault('architecture', variant) - pretrained_cfg = dataclasses.replace(pretrained_cfg, **pretrained_cfg_overlay) - - return pretrained_cfg - - -def build_model_with_cfg( - model_cls: Callable, - variant: str, - pretrained: bool, - pretrained_cfg: Optional[Dict] = None, - pretrained_cfg_overlay: Optional[Dict] = None, - model_cfg: Optional[Any] = None, - feature_cfg: Optional[Dict] = None, - pretrained_strict: bool = True, - pretrained_filter_fn: Optional[Callable] = None, - kwargs_filter: Optional[Tuple[str]] = None, - **kwargs, -): - """ Build model with specified default_cfg and optional model_cfg - - This helper fn aids in the construction of a model including: - * handling default_cfg and associated pretrained weight loading - * passing through optional model_cfg for models with config based arch spec - * features_only model adaptation - * pruning config / model adaptation - - Args: - model_cls (nn.Module): model class - variant (str): model variant name - pretrained (bool): load pretrained weights - pretrained_cfg (dict): model's pretrained weight/task config - model_cfg (Optional[Dict]): model's architecture config - feature_cfg (Optional[Dict]: feature extraction adapter config - pretrained_strict (bool): load pretrained weights strictly - pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights - kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model - **kwargs: model args passed through to model __init__ - """ - pruned = kwargs.pop('pruned', False) - features = False - feature_cfg = feature_cfg or {} - - # resolve and update model pretrained config and model kwargs - pretrained_cfg = resolve_pretrained_cfg( - variant, - pretrained_cfg=pretrained_cfg, - pretrained_cfg_overlay=pretrained_cfg_overlay - ) - - # FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model - pretrained_cfg = pretrained_cfg.to_dict() - - _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter) - - # Setup for feature extraction wrapper done at end of this fn - if kwargs.pop('features_only', False): - features = True - feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4)) - if 'out_indices' in kwargs: - feature_cfg['out_indices'] = kwargs.pop('out_indices') - - # Instantiate the model - if model_cfg is None: - model = model_cls(**kwargs) - else: - model = model_cls(cfg=model_cfg, **kwargs) - model.pretrained_cfg = pretrained_cfg - model.default_cfg = model.pretrained_cfg # alias for backwards compat - - if pruned: - model = adapt_model_from_file(model, variant) - - # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats - num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000)) - if pretrained: - if pretrained_cfg.get('custom_load', False): - load_custom_pretrained( - model, - pretrained_cfg=pretrained_cfg, - ) - else: - load_pretrained( - model, - pretrained_cfg=pretrained_cfg, - num_classes=num_classes_pretrained, - in_chans=kwargs.get('in_chans', 3), - filter_fn=pretrained_filter_fn, - strict=pretrained_strict, - ) - - # Wrap the model in a feature extraction module if enabled - if features: - feature_cls = FeatureListNet - if 'feature_cls' in feature_cfg: - 
feature_cls = feature_cfg.pop('feature_cls') - if isinstance(feature_cls, str): - feature_cls = feature_cls.lower() - if 'hook' in feature_cls: - feature_cls = FeatureHookNet - elif feature_cls == 'fx': - feature_cls = FeatureGraphNet - else: - assert False, f'Unknown feature class {feature_cls}' - model = feature_cls(model, **feature_cfg) - model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg) # add back default_cfg - model.default_cfg = model.pretrained_cfg # alias for backwards compat - - return model - - -def model_parameters(model, exclude_head=False): - if exclude_head: - # FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering - return [p for p in model.parameters()][:-2] - else: - return model.parameters() - - -def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module: - if not depth_first and include_root: - fn(module=module, name=name) - for child_name, child_module in module.named_children(): - child_name = '.'.join((name, child_name)) if name else child_name - named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) - if depth_first and include_root: - fn(module=module, name=name) - return module - - -def named_modules(module: nn.Module, name='', depth_first=True, include_root=False): - if not depth_first and include_root: - yield name, module - for child_name, child_module in module.named_children(): - child_name = '.'.join((name, child_name)) if name else child_name - yield from named_modules( - module=child_module, name=child_name, depth_first=depth_first, include_root=True) - if depth_first and include_root: - yield name, module - - -def named_modules_with_params(module: nn.Module, name='', depth_first=True, include_root=False): - if module._parameters and not depth_first and include_root: - yield name, module - for child_name, child_module in module.named_children(): - child_name = '.'.join((name, child_name)) if name else child_name - yield from named_modules_with_params( - module=child_module, name=child_name, depth_first=depth_first, include_root=True) - if module._parameters and depth_first and include_root: - yield name, module - - -MATCH_PREV_GROUP = (99999,) - - -def group_with_matcher( - named_objects, - group_matcher: Union[Dict, Callable], - output_values: bool = False, - reverse: bool = False -): - if isinstance(group_matcher, dict): - # dictionary matcher contains a dict of raw-string regex expr that must be compiled - compiled = [] - for group_ordinal, (group_name, mspec) in enumerate(group_matcher.items()): - if mspec is None: - continue - # map all matching specifications into 3-tuple (compiled re, prefix, suffix) - if isinstance(mspec, (tuple, list)): - # multi-entry match specifications require each sub-spec to be a 2-tuple (re, suffix) - for sspec in mspec: - compiled += [(re.compile(sspec[0]), (group_ordinal,), sspec[1])] - else: - compiled += [(re.compile(mspec), (group_ordinal,), None)] - group_matcher = compiled - - def _get_grouping(name): - if isinstance(group_matcher, (list, tuple)): - for match_fn, prefix, suffix in group_matcher: - r = match_fn.match(name) - if r: - parts = (prefix, r.groups(), suffix) - # map all tuple elem to int for numeric sort, filter out None entries - return tuple(map(float, chain.from_iterable(filter(None, parts)))) - return float('inf'), # un-matched layers (neck, head) mapped to largest ordinal - else: - ord = group_matcher(name) - if not isinstance(ord, 
collections.abc.Iterable): - return ord, - return tuple(ord) - - # map layers into groups via ordinals (ints or tuples of ints) from matcher - grouping = defaultdict(list) - for k, v in named_objects: - grouping[_get_grouping(k)].append(v if output_values else k) - - # remap to integers - layer_id_to_param = defaultdict(list) - lid = -1 - for k in sorted(filter(lambda x: x is not None, grouping.keys())): - if lid < 0 or k[-1] != MATCH_PREV_GROUP[0]: - lid += 1 - layer_id_to_param[lid].extend(grouping[k]) - - if reverse: - assert not output_values, "reverse mapping only sensible for name output" - # output reverse mapping - param_to_layer_id = {} - for lid, lm in layer_id_to_param.items(): - for n in lm: - param_to_layer_id[n] = lid - return param_to_layer_id - - return layer_id_to_param - - -def group_parameters( - module: nn.Module, - group_matcher, - output_values=False, - reverse=False, -): - return group_with_matcher( - module.named_parameters(), group_matcher, output_values=output_values, reverse=reverse) - - -def group_modules( - module: nn.Module, - group_matcher, - output_values=False, - reverse=False, -): - return group_with_matcher( - named_modules_with_params(module), group_matcher, output_values=output_values, reverse=reverse) - - -def checkpoint_seq( - functions, - x, - every=1, - flatten=False, - skip_last=False, - preserve_rng_state=True -): - r"""A helper function for checkpointing sequential models. - - Sequential models execute a list of modules/functions in order - (sequentially). Therefore, we can divide such a sequence into segments - and checkpoint each segment. All segments except run in :func:`torch.no_grad` - manner, i.e., not storing the intermediate activations. The inputs of each - checkpointed segment will be saved for re-running the segment in the backward pass. - - See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works. - - .. warning:: - Checkpointing currently only supports :func:`torch.autograd.backward` - and only if its `inputs` argument is not passed. :func:`torch.autograd.grad` - is not supported. - - .. warning: - At least one of the inputs needs to have :code:`requires_grad=True` if - grads are needed for model inputs, otherwise the checkpointed part of the - model won't have gradients. - - Args: - functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially. - x: A Tensor that is input to :attr:`functions` - every: checkpoint every-n functions (default: 1) - flatten (bool): flatten nn.Sequential of nn.Sequentials - skip_last (bool): skip checkpointing the last function in the sequence if True - preserve_rng_state (bool, optional, default=True): Omit stashing and restoring - the RNG state during each checkpoint. - - Returns: - Output of running :attr:`functions` sequentially on :attr:`*inputs` - - Example: - >>> model = nn.Sequential(...) 
- >>> input_var = checkpoint_seq(model, input_var, every=2) - """ - def run_function(start, end, functions): - def forward(_x): - for j in range(start, end + 1): - _x = functions[j](_x) - return _x - return forward - - if isinstance(functions, torch.nn.Sequential): - functions = functions.children() - if flatten: - functions = chain.from_iterable(functions) - if not isinstance(functions, (tuple, list)): - functions = tuple(functions) - - num_checkpointed = len(functions) - if skip_last: - num_checkpointed -= 1 - end = -1 - for start in range(0, num_checkpointed, every): - end = min(start + every - 1, num_checkpointed - 1) - x = checkpoint(run_function(start, end, functions), x, preserve_rng_state=preserve_rng_state) - if skip_last: - return run_function(end + 1, len(functions) - 1, functions)(x) - return x - - - def flatten_modules(named_modules, depth=1, prefix='', module_types='sequential'): - prefix_is_tuple = isinstance(prefix, tuple) - if isinstance(module_types, str): - if module_types == 'container': - module_types = (nn.Sequential, nn.ModuleList, nn.ModuleDict) - else: - module_types = (nn.Sequential,) - for name, module in named_modules: - if depth and isinstance(module, module_types): - yield from flatten_modules( - module.named_children(), - depth - 1, - prefix=(name,) if prefix_is_tuple else name, - module_types=module_types, - ) - else: - if prefix_is_tuple: - name = prefix + (name,) - yield name, module - else: - if prefix: - name = '.'.join([prefix, name]) - yield name, module +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning) diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py index 30860120..338d409e 100644 --- a/timm/models/hrnet.py +++ b/timm/models/hrnet.py @@ -16,12 +16,14 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .features import FeatureInfo -from .helpers import build_model_with_cfg, pretrained_cfg_for_features -from .layers import create_classifier -from .registry import register_model +from timm.layers import create_classifier +from ._builder import build_model_with_cfg, pretrained_cfg_for_features +from ._features import FeatureInfo +from ._registry import register_model from .resnet import BasicBlock, Bottleneck # leveraging ResNet blocks w/ additional features like SE +__all__ = ['HighResolutionNet', 'HighResolutionNetFeatures'] # model_registry will add each entrypoint fn to this + _BN_MOMENTUM = 0.1 _logger = logging.getLogger(__name__) diff --git a/timm/models/hub.py b/timm/models/hub.py index 18c5444a..074ca025 100644 --- a/timm/models/hub.py +++ b/timm/models/hub.py @@ -1,217 +1,4 @@ -import json -import logging -import os -from functools import partial -from pathlib import Path -from tempfile import TemporaryDirectory -from typing import Optional, Union +from ._hub import * -import torch -from torch.hub import HASH_REGEX, download_url_to_file, urlparse - -try: - from torch.hub import get_dir -except ImportError: - from torch.hub import _get_torch_home as get_dir - -from timm import __version__ -from timm.models.pretrained import filter_pretrained_cfg - -try: - from huggingface_hub import ( - create_repo, get_hf_file_metadata, - hf_hub_download, hf_hub_url, - repo_type_and_id_from_hf_id, upload_folder) - from huggingface_hub.utils import EntryNotFoundError - hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__) - _has_hf_hub = True -except ImportError: - 
hf_hub_download = None - _has_hf_hub = False - -_logger = logging.getLogger(__name__) - - -def get_cache_dir(child_dir=''): - """ - Returns the location of the directory where models are cached (and creates it if necessary). - """ - # Issue warning to move data if old env is set - if os.getenv('TORCH_MODEL_ZOO'): - _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead') - - hub_dir = get_dir() - child_dir = () if not child_dir else (child_dir,) - model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir) - os.makedirs(model_dir, exist_ok=True) - return model_dir - - -def download_cached_file(url, check_hash=True, progress=False): - if isinstance(url, (list, tuple)): - url, filename = url - else: - parts = urlparse(url) - filename = os.path.basename(parts.path) - cached_file = os.path.join(get_cache_dir(), filename) - if not os.path.exists(cached_file): - _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file)) - hash_prefix = None - if check_hash: - r = HASH_REGEX.search(filename) # r is Optional[Match[str]] - hash_prefix = r.group(1) if r else None - download_url_to_file(url, cached_file, hash_prefix, progress=progress) - return cached_file - - -def has_hf_hub(necessary=False): - if not _has_hf_hub and necessary: - # if no HF Hub module installed, and it is necessary to continue, raise error - raise RuntimeError( - 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') - return _has_hf_hub - - -def hf_split(hf_id): - # FIXME I may change @ -> # and be parsed as fragment in a URI model name scheme - rev_split = hf_id.split('@') - assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.' - hf_model_id = rev_split[0] - hf_revision = rev_split[-1] if len(rev_split) > 1 else None - return hf_model_id, hf_revision - - -def load_cfg_from_json(json_file: Union[str, os.PathLike]): - with open(json_file, "r", encoding="utf-8") as reader: - text = reader.read() - return json.loads(text) - - -def _download_from_hf(model_id: str, filename: str): - hf_model_id, hf_revision = hf_split(model_id) - return hf_hub_download(hf_model_id, filename, revision=hf_revision) - - -def load_model_config_from_hf(model_id: str): - assert has_hf_hub(True) - cached_file = _download_from_hf(model_id, 'config.json') - - hf_config = load_cfg_from_json(cached_file) - if 'pretrained_cfg' not in hf_config: - # old form, pull pretrain_cfg out of the base dict - pretrained_cfg = hf_config - hf_config = {} - hf_config['architecture'] = pretrained_cfg.pop('architecture') - hf_config['num_features'] = pretrained_cfg.pop('num_features', None) - if 'labels' in pretrained_cfg: - hf_config['label_name'] = pretrained_cfg.pop('labels') - hf_config['pretrained_cfg'] = pretrained_cfg - - # NOTE currently discarding parent config as only arch name and pretrained_cfg used in timm right now - pretrained_cfg = hf_config['pretrained_cfg'] - pretrained_cfg['hf_hub_id'] = model_id # insert hf_hub id for pretrained weight load during model creation - pretrained_cfg['source'] = 'hf-hub' - if 'num_classes' in hf_config: - # model should be created with parent num_classes if they exist - pretrained_cfg['num_classes'] = hf_config['num_classes'] - model_name = hf_config['architecture'] - - return pretrained_cfg, model_name - - -def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'): - assert has_hf_hub(True) - cached_file = _download_from_hf(model_id, filename) - state_dict = torch.load(cached_file, 
map_location='cpu') - return state_dict - - -def save_for_hf(model, save_directory, model_config=None): - assert has_hf_hub(True) - model_config = model_config or {} - save_directory = Path(save_directory) - save_directory.mkdir(exist_ok=True, parents=True) - - weights_path = save_directory / 'pytorch_model.bin' - torch.save(model.state_dict(), weights_path) - - config_path = save_directory / 'config.json' - hf_config = {} - pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True) - # set some values at root config level - hf_config['architecture'] = pretrained_cfg.pop('architecture') - hf_config['num_classes'] = model_config.get('num_classes', model.num_classes) - hf_config['num_features'] = model_config.get('num_features', model.num_features) - hf_config['global_pool'] = model_config.get('global_pool', getattr(model, 'global_pool', None)) - - if 'label' in model_config: - _logger.warning( - "'label' as a config field for timm models is deprecated. Please use 'label_name' and 'display_name'. " - "Using provided 'label' field as 'label_name'.") - model_config['label_name'] = model_config.pop('label') - - label_name = model_config.pop('label_name', None) - if label_name: - assert isinstance(label_name, (dict, list, tuple)) - # map label id (classifier index) -> unique label name (ie synset for ImageNet, MID for OpenImages) - # can be a dict id: name if there are id gaps, or tuple/list if no gaps. - hf_config['label_name'] = model_config['label_name'] - - display_name = model_config.pop('display_name', None) - if display_name: - assert isinstance(display_name, dict) - # map label_name -> user interface display name - hf_config['display_name'] = model_config['display_name'] - - hf_config['pretrained_cfg'] = pretrained_cfg - hf_config.update(model_config) - - with config_path.open('w') as f: - json.dump(hf_config, f, indent=2) - - -def push_to_hf_hub( - model, - repo_id: str, - commit_message: str = 'Add model', - token: Optional[str] = None, - revision: Optional[str] = None, - private: bool = False, - create_pr: bool = False, - model_config: Optional[dict] = None, -): - # Create repo if it doesn't exist yet - repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True) - - # Infer complete repo_id from repo_url - # Can be different from the input `repo_id` if repo_owner was implicit - _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url) - repo_id = f"{repo_owner}/{repo_name}" - - # Check if README file already exist in repo - try: - get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision)) - has_readme = True - except EntryNotFoundError: - has_readme = False - - # Dump model and push to Hub - with TemporaryDirectory() as tmpdir: - # Save model weights and config. 
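As a usage sketch of the helper invoked on the next line (assuming `huggingface_hub` is installed and that the function keeps this signature after its move to `_hub`; the model name and output directory are placeholders):

import timm
from timm.models.hub import save_for_hf   # deprecated path, still re-exported by the new shim with a DeprecationWarning

model = timm.create_model('resnet18', pretrained=False)   # illustrative model only
save_for_hf(model, './resnet18-hf-export')                # writes pytorch_model.bin and config.json as shown above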
- save_for_hf(model, tmpdir, model_config=model_config) - - # Add readme if it does not exist - if not has_readme: - model_name = repo_id.split('/')[-1] - readme_path = Path(tmpdir) / "README.md" - readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {model_name}' - readme_path.write_text(readme_text) - - # Upload model and return - return upload_folder( - repo_id=repo_id, - folder_path=tmpdir, - revision=revision, - create_pr=create_pr, - commit_message=commit_message, - ) +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning) diff --git a/timm/models/inception_resnet_v2.py b/timm/models/inception_resnet_v2.py index fa7b8ec8..3006f3d2 100644 --- a/timm/models/inception_resnet_v2.py +++ b/timm/models/inception_resnet_v2.py @@ -7,9 +7,10 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg, flatten_modules -from .layers import create_classifier -from .registry import register_model +from timm.layers import create_classifier +from ._builder import build_model_with_cfg +from ._manipulate import flatten_modules +from ._registry import register_model __all__ = ['InceptionResnetV2'] diff --git a/timm/models/inception_v3.py b/timm/models/inception_v3.py index c70bd608..28794ce6 100644 --- a/timm/models/inception_v3.py +++ b/timm/models/inception_v3.py @@ -8,9 +8,13 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg, resolve_pretrained_cfg, flatten_modules -from .registry import register_model -from .layers import trunc_normal_, create_classifier, Linear +from timm.layers import trunc_normal_, create_classifier, Linear +from ._builder import build_model_with_cfg +from ._builder import resolve_pretrained_cfg +from ._manipulate import flatten_modules +from ._registry import register_model + +__all__ = ['InceptionV3', 'InceptionV3Aux'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/inception_v4.py b/timm/models/inception_v4.py index 5f4e208f..c1559829 100644 --- a/timm/models/inception_v4.py +++ b/timm/models/inception_v4.py @@ -7,9 +7,9 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg -from .layers import create_classifier -from .registry import register_model +from timm.layers import create_classifier +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['InceptionV4'] diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 21c641b6..97e70563 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -1,44 +1,48 @@ -from .activations import * -from .adaptive_avgmax_pool import \ +# NOTE timm.models.layers is DEPRECATED, please use timm.layers, this is here to reduce breakages in transition +from timm.layers.activations import * +from timm.layers.adaptive_avgmax_pool import \ adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d -from .blur_pool import BlurPool2d -from .classifier import ClassifierHead, create_classifier -from .cond_conv2d import CondConv2d, get_condconv_initializer -from .config import 
is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ +from timm.layers.blur_pool import BlurPool2d +from timm.layers.classifier import ClassifierHead, create_classifier +from timm.layers.cond_conv2d import CondConv2d, get_condconv_initializer +from timm.layers.config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ set_layer_config -from .conv2d_same import Conv2dSame, conv2d_same -from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct -from .create_act import create_act_layer, get_act_layer, get_act_fn -from .create_attn import get_attn, create_attn -from .create_conv2d import create_conv2d -from .create_norm import get_norm_layer, create_norm_layer -from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer -from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path -from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn -from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\ +from timm.layers.conv2d_same import Conv2dSame, conv2d_same +from timm.layers.conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct +from timm.layers.create_act import create_act_layer, get_act_layer, get_act_fn +from timm.layers.create_attn import get_attn, create_attn +from timm.layers.create_conv2d import create_conv2d +from timm.layers.create_norm import get_norm_layer, create_norm_layer +from timm.layers.create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer +from timm.layers.drop import DropBlock2d, DropPath, drop_block_2d, drop_path +from timm.layers.eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn +from timm.layers.evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\ EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a -from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm -from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d -from .gather_excite import GatherExcite -from .global_context import GlobalContext -from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple -from .inplace_abn import InplaceAbn -from .linear import Linear -from .mixed_conv2d import MixedConv2d -from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp -from .non_local_attn import NonLocalAttn, BatNonLocalAttn -from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d -from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm -from .padding import get_padding, get_same_padding, pad_same -from .patch_embed import PatchEmbed -from .pool2d_same import AvgPool2dSame, create_pool2d -from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite -from .selective_kernel import SelectiveKernel -from .separable_conv import SeparableConv2d, SeparableConvNormAct -from .space_to_depth import SpaceToDepthModule -from .split_attn import SplitAttn -from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model -from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame -from .test_time_pool import TestTimePoolHead, apply_test_time_pool -from .trace_utils import _assert, _float_to_int -from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_ +from timm.layers.fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm +from timm.layers.filter_response_norm import 
FilterResponseNormTlu2d, FilterResponseNormAct2d +from timm.layers.gather_excite import GatherExcite +from timm.layers.global_context import GlobalContext +from timm.layers.helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple +from timm.layers.inplace_abn import InplaceAbn +from timm.layers.linear import Linear +from timm.layers.mixed_conv2d import MixedConv2d +from timm.layers.mlp import Mlp, GluMlp, GatedMlp, ConvMlp +from timm.layers.non_local_attn import NonLocalAttn, BatNonLocalAttn +from timm.layers.norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d +from timm.layers.norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm +from timm.layers.padding import get_padding, get_same_padding, pad_same +from timm.layers.patch_embed import PatchEmbed +from timm.layers.pool2d_same import AvgPool2dSame, create_pool2d +from timm.layers.squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite +from timm.layers.selective_kernel import SelectiveKernel +from timm.layers.separable_conv import SeparableConv2d, SeparableConvNormAct +from timm.layers.space_to_depth import SpaceToDepthModule +from timm.layers.split_attn import SplitAttn +from timm.layers.split_batchnorm import SplitBatchNorm2d, convert_splitbn_model +from timm.layers.std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame +from timm.layers.test_time_pool import TestTimePoolHead, apply_test_time_pool +from timm.layers.trace_utils import _assert, _float_to_int +from timm.layers.weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_ + +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", DeprecationWarning) diff --git a/timm/models/levit.py b/timm/models/levit.py index cea9f0fc..8dc11309 100644 --- a/timm/models/levit.py +++ b/timm/models/levit.py @@ -23,8 +23,6 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W # Modified from # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # Copyright 2020 Ross Wightman, Apache-2.0 License -import itertools -from copy import deepcopy from functools import partial from typing import Dict @@ -32,10 +30,12 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import to_ntuple, get_act_layer -from .vision_transformer import trunc_normal_ -from .registry import register_model +from timm.layers import to_ntuple, get_act_layer, trunc_normal_ +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model + +__all__ = ['LevitDistilled'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index 3f315093..1e2666e5 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -45,17 +45,17 @@ from typing import Callable, Optional, Union, Tuple, List import torch from torch import nn -from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, checkpoint_seq, named_apply -from .fx_features import register_notrace_function -from .layers import Mlp, ConvMlp, DropPath, ClassifierHead, trunc_normal_tf_, LayerNorm2d, LayerNorm -from .layers import create_attn, get_act_layer, 
get_norm_layer, get_norm_act_layer, create_conv2d -from .layers import SelectAdaptivePool2d, create_pool2d -from .layers import to_2tuple, extend_tuple, make_divisible, _assert -from .pretrained import generate_default_cfgs -from .registry import register_model +from timm.layers import Mlp, ConvMlp, DropPath, ClassifierHead, trunc_normal_tf_, LayerNorm +from timm.layers import SelectAdaptivePool2d, create_pool2d +from timm.layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d +from timm.layers import to_2tuple, extend_tuple, make_divisible, _assert +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._manipulate import named_apply, checkpoint_seq +from ._pretrained import generate_default_cfgs +from ._registry import register_model from .vision_transformer_relpos import RelPosMlp, RelPosBias # FIXME move these to common location __all__ = ['MaxxVitCfg', 'MaxxVitConvCfg', 'MaxxVitTransformerCfg', 'MaxxVit'] diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py index a77e2eb7..a7825899 100644 --- a/timm/models/mlp_mixer.py +++ b/timm/models/mlp_mixer.py @@ -39,16 +39,18 @@ A thank you to paper authors for releasing code and weights. Hacked together by / Copyright 2021 Ross Wightman """ import math -from copy import deepcopy from functools import partial import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, named_apply, checkpoint_seq -from .layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple -from .registry import register_model +from timm.layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple +from ._builder import build_model_with_cfg +from ._manipulate import named_apply, checkpoint_seq +from ._registry import register_model + +__all__ = ['MixerBlock'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index bb72ccb8..cf4f268d 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -14,13 +14,14 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .efficientnet_blocks import SqueezeExcite -from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights,\ +from timm.layers import SelectAdaptivePool2d, Linear, create_conv2d, get_norm_act_layer +from ._builder import build_model_with_cfg, pretrained_cfg_for_features +from ._efficientnet_blocks import SqueezeExcite +from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \ round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT -from .features import FeatureInfo, FeatureHooks -from .helpers import build_model_with_cfg, pretrained_cfg_for_features, checkpoint_seq -from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, get_norm_act_layer -from .registry import register_model +from ._features import FeatureInfo, FeatureHooks +from ._manipulate import checkpoint_seq +from ._registry import register_model __all__ = ['MobileNetV3', 'MobileNetV3Features'] diff --git a/timm/models/mobilevit.py b/timm/models/mobilevit.py index bd5479a7..3d2ae84a 100644 --- a/timm/models/mobilevit.py +++ b/timm/models/mobilevit.py @@ -14,18 +14,18 @@ Rest of code, 
ByobNet, and Transformer block hacked together by / Copyright 2022 # Copyright (C) 2020 Apple Inc. All Rights Reserved. # import math -from typing import Union, Callable, Dict, Tuple, Optional, Sequence +from typing import Callable, Tuple, Optional import torch -from torch import nn import torch.nn.functional as F +from torch import nn +from timm.layers import to_2tuple, make_divisible, GroupNorm1, ConvMlp, DropPath +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_module +from ._registry import register_model from .byobnet import register_block, ByoBlockCfg, ByoModelCfg, ByobNet, LayerFn, num_groups -from .fx_features import register_notrace_module -from .layers import to_2tuple, make_divisible, LayerNorm2d, GroupNorm1, ConvMlp, DropPath from .vision_transformer import Block as TransformerBlock -from .helpers import build_model_with_cfg -from .registry import register_model __all__ = [] diff --git a/timm/models/mvitv2.py b/timm/models/mvitv2.py index c5aaa09e..5c0a6650 100644 --- a/timm/models/mvitv2.py +++ b/timm/models/mvitv2.py @@ -24,10 +24,12 @@ import torch.utils.checkpoint as checkpoint from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_function -from .helpers import build_model_with_cfg -from .layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple -from .registry import register_model +from timm.layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._registry import register_model + +__all__ = ['MultiScaleVit', 'MultiScaleVitCfg'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py index 50db1a3d..0b2178d6 100644 --- a/timm/models/nasnet.py +++ b/timm/models/nasnet.py @@ -8,9 +8,9 @@ import torch import torch.nn as nn import torch.nn.functional as F -from .helpers import build_model_with_cfg -from .layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier -from .registry import register_model +from timm.layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['NASNetALarge'] diff --git a/timm/models/nest.py b/timm/models/nest.py index 8692a2b1..c9c6258c 100644 --- a/timm/models/nest.py +++ b/timm/models/nest.py @@ -25,12 +25,14 @@ import torch.nn.functional as F from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_function -from .helpers import build_model_with_cfg, named_apply, checkpoint_seq -from .layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_ -from .layers import _assert -from .layers import create_conv2d, create_pool2d, to_ntuple -from .registry import register_model +from timm.layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_, _assert +from timm.layers import create_conv2d, create_pool2d, to_ntuple +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._manipulate import checkpoint_seq, named_apply +from ._registry import register_model + +__all__ = ['Nest'] # model_registry will add each entrypoint fn to this _logger = logging.getLogger(__name__) diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index 3a45410b..48f91b35 100644 
--- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -16,21 +16,23 @@ Status: Hacked together by / copyright Ross Wightman, 2021. """ -import math -from dataclasses import dataclass, field from collections import OrderedDict -from typing import Tuple, Optional +from dataclasses import dataclass from functools import partial +from typing import Tuple, Optional import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_module -from .helpers import build_model_with_cfg, checkpoint_seq -from .registry import register_model -from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\ +from timm.layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame, \ get_act_layer, get_act_fn, get_attn, make_divisible +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_module +from ._manipulate import checkpoint_seq +from ._registry import register_model + +__all__ = ['NormFreeNet', 'NfCfg'] # model_registry will add each entrypoint fn to this def _dcfg(url='', **kwargs): diff --git a/timm/models/pit.py b/timm/models/pit.py index 0f571319..4f40e5e0 100644 --- a/timm/models/pit.py +++ b/timm/models/pit.py @@ -13,7 +13,6 @@ Modifications for timm by / Copyright 2020 Ross Wightman import math import re -from copy import deepcopy from functools import partial from typing import Tuple @@ -21,12 +20,15 @@ import torch from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import trunc_normal_, to_2tuple -from .registry import register_model +from timm.layers import trunc_normal_, to_2tuple +from ._builder import build_model_with_cfg +from ._registry import register_model from .vision_transformer import Block +__all__ = ['PoolingVisionTransformer'] # model_registry will add each entrypoint fn to this + + def _cfg(url='', **kwargs): return { 'url': url, diff --git a/timm/models/pnasnet.py b/timm/models/pnasnet.py index 81067845..7291c8fb 100644 --- a/timm/models/pnasnet.py +++ b/timm/models/pnasnet.py @@ -12,9 +12,9 @@ import torch import torch.nn as nn import torch.nn.functional as F -from .helpers import build_model_with_cfg -from .layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier -from .registry import register_model +from timm.layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['PNASNet5Large'] diff --git a/timm/models/poolformer.py b/timm/models/poolformer.py index 09359bc8..b4d2d18f 100644 --- a/timm/models/poolformer.py +++ b/timm/models/poolformer.py @@ -19,15 +19,15 @@ Modifications and additions for timm by / Copyright 2022, Ross Wightman # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import os -import copy import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import DropPath, trunc_normal_, to_2tuple, ConvMlp, GroupNorm1 -from .registry import register_model +from timm.layers import DropPath, trunc_normal_, to_2tuple, ConvMlp, GroupNorm1 +from ._builder import build_model_with_cfg +from ._registry import register_model + +__all__ = ['PoolFormer'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/pvt_v2.py b/timm/models/pvt_v2.py index dd3cf690..696a2506 100644 --- a/timm/models/pvt_v2.py +++ b/timm/models/pvt_v2.py @@ -24,9 +24,9 @@ import torch.nn as nn import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import DropPath, to_2tuple, to_ntuple, trunc_normal_ -from .registry import register_model +from timm.layers import DropPath, to_2tuple, to_ntuple, trunc_normal_ +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['PyramidVisionTransformerV2'] diff --git a/timm/models/registry.py b/timm/models/registry.py index 159ffb5f..58e2e1f4 100644 --- a/timm/models/registry.py +++ b/timm/models/registry.py @@ -1,210 +1,4 @@ -""" Model Registry -Hacked together by / Copyright 2020 Ross Wightman -""" +from ._registry import * -import fnmatch -import re -import sys -from collections import defaultdict, deque -from copy import deepcopy -from typing import List, Optional, Union, Tuple - -from .pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag - -__all__ = [ - 'list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules', - 'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name'] - -_module_to_models = defaultdict(set) # dict of sets to check membership of model in module -_model_to_module = {} # mapping of model names to module names -_model_entrypoints = {} # mapping of model names to architecture entrypoint fns -_model_has_pretrained = set() # set of model names that have pretrained weight url present -_model_default_cfgs = dict() # central repo for model arch -> default cfg objects -_model_pretrained_cfgs = dict() # central repo for model arch + tag -> pretrained cfgs -_model_with_tags = defaultdict(list) # shortcut to map each model arch to all model + tag names - - -def get_arch_name(model_name: str) -> Tuple[str, Optional[str]]: - return split_model_name_tag(model_name)[0] - - -def register_model(fn): - # lookup containing module - mod = sys.modules[fn.__module__] - module_name_split = fn.__module__.split('.') - module_name = module_name_split[-1] if len(module_name_split) else '' - - # add model to __all__ in module - model_name = fn.__name__ - if hasattr(mod, '__all__'): - mod.__all__.append(model_name) - else: - mod.__all__ = [model_name] - - # add entries to registry dict/sets - _model_entrypoints[model_name] = fn - _model_to_module[model_name] = module_name - _module_to_models[module_name].add(model_name) - if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs: - # this will catch all models that have entrypoint matching cfg key, but miss any aliasing - # entrypoints or non-matching combos - cfg = mod.default_cfgs[model_name] - if not isinstance(cfg, DefaultCfg): - # new style default cfg dataclass w/ multiple entries per model-arch - assert isinstance(cfg, dict) - # old style cfg 
dict per model-arch - cfg = PretrainedCfg(**cfg) - cfg = DefaultCfg(tags=deque(['']), cfgs={'': cfg}) - - for tag_idx, tag in enumerate(cfg.tags): - is_default = tag_idx == 0 - pretrained_cfg = cfg.cfgs[tag] - if is_default: - _model_pretrained_cfgs[model_name] = pretrained_cfg - if pretrained_cfg.has_weights: - # add tagless entry if it's default and has weights - _model_has_pretrained.add(model_name) - if tag: - model_name_tag = '.'.join([model_name, tag]) - _model_pretrained_cfgs[model_name_tag] = pretrained_cfg - if pretrained_cfg.has_weights: - # add model w/ tag if tag is valid - _model_has_pretrained.add(model_name_tag) - _model_with_tags[model_name].append(model_name_tag) - else: - _model_with_tags[model_name].append(model_name) # has empty tag (to slowly remove these instances) - - _model_default_cfgs[model_name] = cfg - - return fn - - -def _natural_key(string_): - return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] - - -def list_models( - filter: Union[str, List[str]] = '', - module: str = '', - pretrained=False, - exclude_filters: str = '', - name_matches_cfg: bool = False, - include_tags: Optional[bool] = None, -): - """ Return list of available model names, sorted alphabetically - - Args: - filter (str) - Wildcard filter string that works with fnmatch - module (str) - Limit model selection to a specific submodule (ie 'vision_transformer') - pretrained (bool) - Include only models with valid pretrained weights if True - exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter - name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases) - include_tags (Optional[boo]) - Include pretrained tags in model names (model.tag). If None, defaults - set to True when pretrained=True else False (default: None) - Example: - model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet' - model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module - """ - if include_tags is None: - # FIXME should this be default behaviour? or default to include_tags=True? 
- include_tags = pretrained - - if module: - all_models = list(_module_to_models[module]) - else: - all_models = _model_entrypoints.keys() - - if include_tags: - # expand model names to include names w/ pretrained tags - models_with_tags = [] - for m in all_models: - models_with_tags.extend(_model_with_tags[m]) - all_models = models_with_tags - - if filter: - models = [] - include_filters = filter if isinstance(filter, (tuple, list)) else [filter] - for f in include_filters: - include_models = fnmatch.filter(all_models, f) # include these models - if len(include_models): - models = set(models).union(include_models) - else: - models = all_models - - if exclude_filters: - if not isinstance(exclude_filters, (tuple, list)): - exclude_filters = [exclude_filters] - for xf in exclude_filters: - exclude_models = fnmatch.filter(models, xf) # exclude these models - if len(exclude_models): - models = set(models).difference(exclude_models) - - if pretrained: - models = _model_has_pretrained.intersection(models) - - if name_matches_cfg: - models = set(_model_pretrained_cfgs).intersection(models) - - return list(sorted(models, key=_natural_key)) - - -def list_pretrained( - filter: Union[str, List[str]] = '', - exclude_filters: str = '', -): - return list_models( - filter=filter, - pretrained=True, - exclude_filters=exclude_filters, - include_tags=True, - ) - - -def is_model(model_name): - """ Check if a model name exists - """ - arch_name = get_arch_name(model_name) - return arch_name in _model_entrypoints - - -def model_entrypoint(model_name): - """Fetch a model entrypoint for specified model name - """ - arch_name = get_arch_name(model_name) - return _model_entrypoints[arch_name] - - -def list_modules(): - """ Return list of module names that contain models / model entrypoints - """ - modules = _module_to_models.keys() - return list(sorted(modules)) - - -def is_model_in_modules(model_name, module_names): - """Check if a model exists within a subset of modules - Args: - model_name (str) - name of model to check - module_names (tuple, list, set) - names of modules to search in - """ - arch_name = get_arch_name(model_name) - assert isinstance(module_names, (tuple, list, set)) - return any(arch_name in _module_to_models[n] for n in module_names) - - -def is_model_pretrained(model_name): - return model_name in _model_has_pretrained - - -def get_pretrained_cfg(model_name): - if model_name in _model_pretrained_cfgs: - return deepcopy(_model_pretrained_cfgs[model_name]) - raise RuntimeError(f'No pretrained config exists for model {model_name}.') - - -def get_pretrained_cfg_value(model_name, cfg_key): - """ Get a specific model default_cfg value by key. None if key doesn't exist. 
- """ - if model_name in _model_pretrained_cfgs: - return getattr(_model_pretrained_cfgs[model_name], cfg_key, None) - raise RuntimeError(f'No pretrained config exist for model {model_name}.') \ No newline at end of file +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning) diff --git a/timm/models/regnet.py b/timm/models/regnet.py index 0ad7c826..e1cc821b 100644 --- a/timm/models/regnet.py +++ b/timm/models/regnet.py @@ -23,10 +23,13 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, named_apply, checkpoint_seq -from .layers import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, GroupNormAct -from .layers import get_act_layer, get_norm_act_layer, create_conv2d -from .registry import register_model +from timm.layers import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, GroupNormAct +from timm.layers import get_act_layer, get_norm_act_layer, create_conv2d +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq, named_apply +from ._registry import register_model + +__all__ = ['RegNet', 'RegNetCfg'] # model_registry will add each entrypoint fn to this @dataclass diff --git a/timm/models/res2net.py b/timm/models/res2net.py index 6c2fd1bf..4724df2a 100644 --- a/timm/models/res2net.py +++ b/timm/models/res2net.py @@ -8,8 +8,8 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .registry import register_model +from ._builder import build_model_with_cfg +from ._registry import register_model from .resnet import ResNet __all__ = [] diff --git a/timm/models/resnest.py b/timm/models/resnest.py index 735b91a2..3b001c7b 100644 --- a/timm/models/resnest.py +++ b/timm/models/resnest.py @@ -6,13 +6,12 @@ Adapted from original PyTorch impl w/ weights at https://github.com/zhanghang198 Modified for torchscript compat, and consistency with timm by Ross Wightman """ -import torch from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import SplitAttn -from .registry import register_model +from timm.layers import SplitAttn +from ._builder import build_model_with_cfg +from ._registry import register_model from .resnet import ResNet diff --git a/timm/models/resnet.py b/timm/models/resnet.py index d0d98894..50849017 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -15,9 +15,11 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, create_attn, get_attn, create_classifier -from .registry import register_model +from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, create_attn, get_attn, \ + create_classifier +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model, model_entrypoint __all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this @@ -675,6 +677,11 @@ class ResNet(nn.Module): self.init_weights(zero_init_last=zero_init_last) + @staticmethod + def from_pretrained(model_name: str, load_weights=True, **kwargs) -> 'ResNet': + entry_fn = 
model_entrypoint(model_name, 'resnet') + return entry_fn(pretrained=load_weights, **kwargs) + @torch.jit.ignore def init_weights(self, zero_init_last=True): for n, m in self.named_modules(): @@ -822,7 +829,7 @@ def resnet50(pretrained=False, **kwargs): @register_model -def resnet50d(pretrained=False, **kwargs): +def resnet50d(pretrained=False, **kwargs) -> ResNet: """Constructs a ResNet-50-D model. """ model_args = dict( diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index b21ef7f5..f8c4298b 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -30,16 +30,19 @@ Original copyright of Google code below, modifications by Ross Wightman, Copyrig # limitations under the License. from collections import OrderedDict # pylint: disable=g-importing-member +from functools import partial import torch import torch.nn as nn -from functools import partial from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg, named_apply, adapt_input_conv, checkpoint_seq -from .registry import register_model -from .layers import GroupNormAct, BatchNormAct2d, EvoNorm2dB0, EvoNorm2dS0, EvoNorm2dS1, FilterResponseNormTlu2d,\ +from timm.layers import GroupNormAct, BatchNormAct2d, EvoNorm2dB0, EvoNorm2dS0, FilterResponseNormTlu2d, \ ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq, named_apply, adapt_input_conv +from ._registry import register_model + +__all__ = ['ResNetV2'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py index 33e97222..51e8cdc2 100644 --- a/timm/models/rexnet.py +++ b/timm/models/rexnet.py @@ -10,16 +10,20 @@ Changes for timm, feature extraction, and rounded channel variant hacked togethe Copyright 2020 Ross Wightman """ -import torch -import torch.nn as nn from functools import partial from math import ceil +import torch +import torch.nn as nn + from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import ClassifierHead, create_act_layer, ConvNormAct, DropPath, make_divisible, SEModule -from .registry import register_model -from .efficientnet_builder import efficientnet_init_weights +from timm.layers import ClassifierHead, create_act_layer, ConvNormAct, DropPath, make_divisible, SEModule +from ._builder import build_model_with_cfg +from ._efficientnet_builder import efficientnet_init_weights +from ._manipulate import checkpoint_seq +from ._registry import register_model + +__all__ = ['ReXNetV1'] # model_registry will add each entrypoint fn to this def _cfg(url=''): diff --git a/timm/models/selecsls.py b/timm/models/selecsls.py index 1a9ac929..4d40c49a 100644 --- a/timm/models/selecsls.py +++ b/timm/models/selecsls.py @@ -16,9 +16,9 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import create_classifier -from .registry import register_model +from timm.layers import create_classifier +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['SelecSLS'] # model_registry will add each entrypoint fn to this diff --git a/timm/models/senet.py b/timm/models/senet.py index a9e23ff1..d36e9854 100644 --- a/timm/models/senet.py +++ b/timm/models/senet.py @@ -19,9 +19,9 @@ 
import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import create_classifier -from .registry import register_model +from timm.layers import create_classifier +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['SENet'] diff --git a/timm/models/sequencer.py b/timm/models/sequencer.py index b1ae92a4..f3f758b9 100644 --- a/timm/models/sequencer.py +++ b/timm/models/sequencer.py @@ -6,7 +6,6 @@ Paper: `Sequencer: Deep LSTM for Image Classification` - https://arxiv.org/pdf/2 # Copyright (c) 2022. Yuki Tatsunami # Licensed under the Apache License, Version 2.0 (the "License"); - import math from functools import partial from typing import Tuple @@ -15,9 +14,12 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT -from .helpers import build_model_with_cfg, named_apply -from .layers import lecun_normal_, DropPath, Mlp, PatchEmbed as TimmPatchEmbed -from .registry import register_model +from timm.layers import lecun_normal_, DropPath, Mlp, PatchEmbed as TimmPatchEmbed +from ._builder import build_model_with_cfg +from ._manipulate import named_apply +from ._registry import register_model + +__all__ = ['Sequencer2D'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/sknet.py b/timm/models/sknet.py index fb9f063a..5a29b9a4 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -13,9 +13,9 @@ import math from torch import nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import SelectiveKernel, ConvNormAct, ConvNormActAa, create_attn -from .registry import register_model +from timm.layers import SelectiveKernel, ConvNormAct, create_attn +from ._builder import build_model_with_cfg +from ._registry import register_model from .resnet import ResNet diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index f2305fb2..5df06d4d 100644 --- a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -17,19 +17,20 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W # -------------------------------------------------------- import logging import math -from functools import partial from typing import Optional import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_function -from .helpers import build_model_with_cfg, named_apply, checkpoint_seq -from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert -from .registry import register_model +from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._manipulate import checkpoint_seq, named_apply +from ._registry import register_model from .vision_transformer import checkpoint_filter_fn, get_init_weights_vit +__all__ = ['SwinTransformer'] # model_registry will add each entrypoint fn to this _logger = logging.getLogger(__name__) diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py index 0c9db3dd..efaaa9e9 100644 --- a/timm/models/swin_transformer_v2.py +++ b/timm/models/swin_transformer_v2.py @@ -21,10 +21,12 @@ import 
torch.nn.functional as F import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_function -from .helpers import build_model_with_cfg, named_apply -from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert -from .registry import register_model +from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._registry import register_model + +__all__ = ['SwinTransformerV2'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index d143c14c..cf10b39c 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -29,7 +29,6 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W # -------------------------------------------------------- import logging import math -from copy import deepcopy from typing import Tuple, Optional, List, Union, Any, Type import torch @@ -38,11 +37,13 @@ import torch.nn.functional as F import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_function -from .helpers import build_model_with_cfg, named_apply -from .layers import DropPath, Mlp, to_2tuple, _assert -from .registry import register_model +from timm.layers import DropPath, Mlp, to_2tuple, _assert +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._manipulate import named_apply +from ._registry import register_model +__all__ = ['SwinTransformerV2Cr'] # model_registry will add each entrypoint fn to this _logger = logging.getLogger(__name__) diff --git a/timm/models/tnt.py b/timm/models/tnt.py index 5b72b196..50088baf 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -7,17 +7,18 @@ The official mindspore code is released and available at https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT """ import math + import torch import torch.nn as nn from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.models.helpers import build_model_with_cfg -from timm.models.layers import Mlp, DropPath, trunc_normal_ -from timm.models.layers.helpers import to_2tuple -from timm.models.layers import _assert -from timm.models.registry import register_model -from timm.models.vision_transformer import resize_pos_embed +from timm.layers import Mlp, DropPath, trunc_normal_, _assert, to_2tuple +from ._builder import build_model_with_cfg +from ._registry import register_model +from .vision_transformer import resize_pos_embed + +__all__ = ['TNT'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py index 2469acd2..83cb0576 100644 --- a/timm/models/tresnet.py +++ b/timm/models/tresnet.py @@ -10,11 +10,11 @@ from collections import OrderedDict import torch import torch.nn as nn -from .helpers import build_model_with_cfg -from .layers import SpaceToDepthModule, BlurPool2d, InplaceAbn, ClassifierHead, SEModule -from .registry import register_model +from timm.layers import SpaceToDepthModule, BlurPool2d, InplaceAbn, ClassifierHead, SEModule +from ._builder import build_model_with_cfg 
+from ._registry import register_model -__all__ = ['tresnet_m', 'tresnet_l', 'tresnet_xl'] +__all__ = ['TResNet'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/twins.py b/timm/models/twins.py index 0626db37..41944c36 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -12,20 +12,21 @@ Code/weights from https://github.com/Meituan-AutoML/Twins, original copyright/li # Written by Xinjie Li, Xiangxiang Chu # -------------------------------------------------------- import math -from copy import deepcopy -from typing import Optional, Tuple +from functools import partial +from typing import Tuple import torch import torch.nn as nn import torch.nn.functional as F -from functools import partial from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .layers import Mlp, DropPath, to_2tuple, trunc_normal_ -from .fx_features import register_notrace_module -from .registry import register_model +from timm.layers import Mlp, DropPath, to_2tuple, trunc_normal_ +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_module +from ._registry import register_model from .vision_transformer import Attention -from .helpers import build_model_with_cfg + +__all__ = ['Twins'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/vgg.py b/timm/models/vgg.py index caf96517..abe9f8d5 100644 --- a/timm/models/vgg.py +++ b/timm/models/vgg.py @@ -5,21 +5,19 @@ timm functionality. Copyright 2021 Ross Wightman """ +from typing import Union, List, Dict, Any, cast + import torch import torch.nn as nn import torch.nn.functional as F -from typing import Union, List, Dict, Any, cast from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, checkpoint_seq -from .fx_features import register_notrace_module -from .layers import ClassifierHead -from .registry import register_model - -__all__ = [ - 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', - 'vgg19_bn', 'vgg19', -] +from timm.layers import ClassifierHead +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_module +from ._registry import register_model + +__all__ = ['VGG'] def _cfg(url='', **kwargs): diff --git a/timm/models/visformer.py b/timm/models/visformer.py index 254a0748..e15ae4a5 100644 --- a/timm/models/visformer.py +++ b/timm/models/visformer.py @@ -6,17 +6,15 @@ From original at https://github.com/danczs/Visformer Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman """ -from copy import deepcopy import torch import torch.nn as nn -import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d, create_classifier -from .registry import register_model - +from timm.layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d, create_classifier +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model __all__ = ['Visformer'] diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 820dc656..5b93628f 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -19,10 +19,10 @@ for some einops/einsum fun Hacked together by / Copyright 2020, Ross Wightman 
""" -import math import logging -from functools import partial +import math from collections import OrderedDict +from functools import partial from typing import Optional import torch @@ -30,12 +30,17 @@ import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD,\ +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \ OPENAI_CLIP_MEAN, OPENAI_CLIP_STD -from .helpers import build_model_with_cfg, named_apply, adapt_input_conv, checkpoint_seq -from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ -from .pretrained import generate_default_cfgs -from .registry import register_model +from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ +from ._builder import build_model_with_cfg +from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv +from ._pretrained import generate_default_cfgs +from ._registry import register_model + + +__all__ = ['VisionTransformer'] # model_registry will add each entrypoint fn to this + _logger = logging.getLogger(__name__) diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index 5e5113d7..cfdd0a0e 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -13,19 +13,18 @@ They were moved here to keep file sizes sane. Hacked together by / Copyright 2020, Ross Wightman """ -from copy import deepcopy from functools import partial import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .layers import StdConv2dSame, StdConv2d, to_2tuple -from .pretrained import generate_default_cfgs +from timm.layers import StdConv2dSame, StdConv2d, to_2tuple +from ._pretrained import generate_default_cfgs +from ._registry import register_model from .resnet import resnet26d, resnet50d from .resnetv2 import ResNetV2, create_resnetv2_stem -from .registry import register_model -from timm.models.vision_transformer import _create_vision_transformer +from .vision_transformer import _create_vision_transformer def _cfg(url='', **kwargs): diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py index 52b3ce45..1a7c2f40 100644 --- a/timm/models/vision_transformer_relpos.py +++ b/timm/models/vision_transformer_relpos.py @@ -4,11 +4,9 @@ NOTE: these models are experimental / WIP, expect changes Hacked together by / Copyright 2022, Ross Wightman """ -import math import logging +import math from functools import partial -from collections import OrderedDict -from dataclasses import dataclass from typing import Optional, Tuple import torch @@ -16,10 +14,12 @@ import torch.nn as nn import torch.nn.functional as F from torch.utils.checkpoint import checkpoint -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply -from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, to_2tuple -from .registry import register_model +from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_ +from ._builder import build_model_with_cfg +from ._registry import register_model + +__all__ = ['VisionTransformerRelPos'] # model_registry will add each 
entrypoint fn to this _logger = logging.getLogger(__name__) diff --git a/timm/models/volo.py b/timm/models/volo.py index 735453c8..1117995a 100644 --- a/timm/models/volo.py +++ b/timm/models/volo.py @@ -20,17 +20,19 @@ Modifications and additions for timm by / Copyright 2022, Ross Wightman # See the License for the specific language governing permissions and # limitations under the License. import math -import numpy as np +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.models.layers import DropPath, Mlp, to_2tuple, to_ntuple, trunc_normal_ -from timm.models.registry import register_model -from timm.models.helpers import build_model_with_cfg +from timm.layers import DropPath, Mlp, to_2tuple, to_ntuple, trunc_normal_ +from ._builder import build_model_with_cfg +from ._registry import register_model + +__all__ = ['VOLO'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/vovnet.py b/timm/models/vovnet.py index 39d37195..bf0e4f89 100644 --- a/timm/models/vovnet.py +++ b/timm/models/vovnet.py @@ -15,13 +15,15 @@ from typing import List import torch import torch.nn as nn -import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .registry import register_model -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import ConvNormAct, SeparableConvNormAct, BatchNormAct2d, ClassifierHead, DropPath,\ +from timm.layers import ConvNormAct, SeparableConvNormAct, BatchNormAct2d, ClassifierHead, DropPath, \ create_attn, create_norm_act_layer, get_norm_act_layer +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model + +__all__ = ['VovNet'] # model_registry will add each entrypoint fn to this # model cfgs adapted from https://github.com/youngwanLEE/vovnet-detectron2 & diff --git a/timm/models/xception.py b/timm/models/xception.py index 99d02c46..99e74b46 100644 --- a/timm/models/xception.py +++ b/timm/models/xception.py @@ -25,9 +25,9 @@ import torch.jit import torch.nn as nn import torch.nn.functional as F -from .helpers import build_model_with_cfg -from .layers import create_classifier -from .registry import register_model +from timm.layers import create_classifier +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['Xception'] diff --git a/timm/models/xception_aligned.py b/timm/models/xception_aligned.py index 6bbce5e6..e3348e64 100644 --- a/timm/models/xception_aligned.py +++ b/timm/models/xception_aligned.py @@ -11,10 +11,11 @@ import torch import torch.nn as nn from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import ClassifierHead, ConvNormAct, create_conv2d, get_norm_act_layer -from .layers.helpers import to_3tuple -from .registry import register_model +from timm.layers import ClassifierHead, ConvNormAct, create_conv2d, get_norm_act_layer +from timm.layers.helpers import to_3tuple +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model __all__ = ['XceptionAligned'] diff --git a/timm/models/xcit.py b/timm/models/xcit.py index 6802fc84..57c11183 100644 --- a/timm/models/xcit.py +++ b/timm/models/xcit.py @@ -19,12 +19,14 @@ import torch.nn as nn from 
torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .vision_transformer import _cfg, Mlp -from .registry import register_model -from .layers import DropPath, trunc_normal_, to_2tuple +from timm.layers import DropPath, trunc_normal_, to_2tuple +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_module +from ._registry import register_model from .cait import ClassAttn -from .fx_features import register_notrace_module +from .vision_transformer import Mlp + +__all__ = ['XCiT'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py index 02f0e250..8613a62c 100644 --- a/timm/optim/optim_factory.py +++ b/timm/optim/optim_factory.py @@ -9,7 +9,7 @@ import torch import torch.nn as nn import torch.optim as optim -from timm.models.helpers import group_parameters +from timm.models import group_parameters from .adabelief import AdaBelief from .adafactor import Adafactor diff --git a/timm/version.py b/timm/version.py index 0f19999f..0716d38a 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.8.0dev0' +__version__ = '0.8.1dev0' diff --git a/train.py b/train.py index b85eb6b0..e51d7c90 100755 --- a/train.py +++ b/train.py @@ -31,10 +31,9 @@ from torch.nn.parallel import DistributedDataParallel as NativeDDP from timm import utils from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset -from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, \ - LabelSmoothingCrossEntropy -from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, \ - convert_splitbn_model, convert_sync_batchnorm, model_parameters, set_fast_norm +from timm.layers import convert_splitbn_model, convert_sync_batchnorm, set_fast_norm +from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, LabelSmoothingCrossEntropy +from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, model_parameters from timm.optim import create_optimizer_v2, optimizer_kwargs from timm.scheduler import create_scheduler_v2, scheduler_kwargs from timm.utils import ApexScaler, NativeScaler @@ -82,7 +81,7 @@ parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') # Dataset parameters group = parser.add_argument_group('Dataset parameters') -# Keep this argument outside of the dataset group because it is positional. +# Keep this argument outside the dataset group because it is positional. parser.add_argument('data', nargs='?', metavar='DIR', const=None, help='path to dataset (positional is *deprecated*, use --data-dir)') parser.add_argument('--data-dir', metavar='DIR', diff --git a/validate.py b/validate.py index 872f27b0..4669fbac 100755 --- a/validate.py +++ b/validate.py @@ -8,22 +8,24 @@ canonical PyTorch, standard Python style, and good performance. 
Repurpose as you Hacked together by Ross Wightman (https://github.com/rwightman) """ import argparse -import os import csv import glob import json -import time import logging -import torch -import torch.nn as nn -import torch.nn.parallel +import os +import time from collections import OrderedDict from contextlib import suppress from functools import partial -from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models, set_fast_norm +import torch +import torch.nn as nn +import torch.nn.parallel + from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet -from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser,\ +from timm.layers import apply_test_time_pool, set_fast_norm +from timm.models import create_model, load_checkpoint, is_model, list_models +from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser, \ decay_batch_step, check_batch_size_retry try: