Add a deprecation phase to module re-org

pull/1581/head
Ross Wightman 1 year ago
parent 927f031293
commit cda39b35bd

@ -19,7 +19,8 @@ import torch.nn as nn
import torch.nn.parallel
from timm.data import resolve_data_config
from timm.models import create_model, is_model, list_models, set_fast_norm
from timm.layers import set_fast_norm
from timm.models import create_model, is_model, list_models
from timm.optim import create_optimizer_v2
from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry

@ -23,6 +23,10 @@ _DOWNLOAD_PROGRESS = False
_CHECK_HASH = False
__all__ = ['set_pretrained_download_progress', 'set_pretrained_check_hash', 'load_custom_pretrained', 'load_pretrained',
'pretrained_cfg_for_features', 'resolve_pretrained_cfg', 'build_model_with_cfg']
def _resolve_pretrained_source(pretrained_cfg):
cfg_source = pretrained_cfg.get('source', '')
pretrained_url = pretrained_cfg.get('url', None)

@ -9,6 +9,9 @@ from ._hub import load_model_config_from_hf
from ._registry import is_model, model_entrypoint
__all__ = ['parse_model_name', 'safe_model_name', 'create_model']
def parse_model_name(model_name):
if model_name.startswith('hf_hub'):
# NOTE for backwards compat, deprecate hf_hub use

@ -17,6 +17,9 @@ import torch
import torch.nn as nn
__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']
class FeatureInfo:
def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):

@ -35,6 +35,10 @@ except ImportError:
pass
__all__ = ['register_notrace_module', 'register_notrace_function', 'create_feature_extractor',
'FeatureGraphNet', 'GraphExtractNet']
def register_notrace_module(module: Type[nn.Module]):
"""
Any module not under timm.models.layers should get this decorator if we don't want to trace through it.

@ -12,6 +12,8 @@ import timm.models._builder
_logger = logging.getLogger(__name__)
__all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_checkpoint', 'resume_checkpoint']
def clean_state_dict(state_dict):
# 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training

@ -31,6 +31,9 @@ except ImportError:
_logger = logging.getLogger(__name__)
__all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf',
'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub']
def get_cache_dir(child_dir=''):
"""

@ -9,6 +9,9 @@ import torch
from torch import nn as nn
from torch.utils.checkpoint import checkpoint
__all__ = ['model_parameters', 'named_apply', 'named_modules', 'named_modules_with_params', 'adapt_input_conv',
'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq']
def model_parameters(model, exclude_head=False):
if exclude_head:

@ -4,6 +4,9 @@ from dataclasses import dataclass, field, replace, asdict
from typing import Any, Deque, Dict, Tuple, Optional, Union
__all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg', 'split_model_name_tag', 'generate_default_cfgs']
@dataclass
class PretrainedCfg:
"""

@ -5,6 +5,8 @@ from torch import nn as nn
from timm.layers import Conv2dSame, BatchNormAct2d, Linear
__all__ = ['extract_layer', 'set_layer', 'adapt_model_from_string', 'adapt_model_from_file']
def extract_layer(model, layer):
layer = layer.split('.')

@ -12,7 +12,7 @@ from typing import List, Optional, Union, Tuple
from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
__all__ = [
'list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name']
_module_to_models = defaultdict(set) # dict of sets to check membership of model in module

@ -0,0 +1,4 @@
from ._factory import *
import warnings
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)

@ -0,0 +1,4 @@
from ._features import *
import warnings
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)

@ -0,0 +1,4 @@
from ._features_fx import *
import warnings
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)

@ -0,0 +1,7 @@
from ._builder import *
from ._helpers import *
from ._manipulate import *
from ._prune import *
import warnings
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)

@ -0,0 +1,4 @@
from _hub import *
import warnings
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)

@ -43,3 +43,6 @@ from timm.layers.std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, Scal
from timm.layers.test_time_pool import TestTimePoolHead, apply_test_time_pool
from timm.layers.trace_utils import _assert, _float_to_int
from timm.layers.weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_
import warnings
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)

@ -0,0 +1,4 @@
from ._registry import *
import warnings
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
Loading…
Cancel
Save