Add a deprecation phase to module re-org

pull/1581/head
Ross Wightman 1 year ago
parent 927f031293
commit cda39b35bd

@ -19,7 +19,8 @@ import torch.nn as nn
import torch.nn.parallel import torch.nn.parallel
from timm.data import resolve_data_config from timm.data import resolve_data_config
from timm.models import create_model, is_model, list_models, set_fast_norm from timm.layers import set_fast_norm
from timm.models import create_model, is_model, list_models
from timm.optim import create_optimizer_v2 from timm.optim import create_optimizer_v2
from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry

@ -23,6 +23,10 @@ _DOWNLOAD_PROGRESS = False
_CHECK_HASH = False _CHECK_HASH = False
__all__ = ['set_pretrained_download_progress', 'set_pretrained_check_hash', 'load_custom_pretrained', 'load_pretrained',
'pretrained_cfg_for_features', 'resolve_pretrained_cfg', 'build_model_with_cfg']
def _resolve_pretrained_source(pretrained_cfg): def _resolve_pretrained_source(pretrained_cfg):
cfg_source = pretrained_cfg.get('source', '') cfg_source = pretrained_cfg.get('source', '')
pretrained_url = pretrained_cfg.get('url', None) pretrained_url = pretrained_cfg.get('url', None)

@ -9,6 +9,9 @@ from ._hub import load_model_config_from_hf
from ._registry import is_model, model_entrypoint from ._registry import is_model, model_entrypoint
__all__ = ['parse_model_name', 'safe_model_name', 'create_model']
def parse_model_name(model_name): def parse_model_name(model_name):
if model_name.startswith('hf_hub'): if model_name.startswith('hf_hub'):
# NOTE for backwards compat, deprecate hf_hub use # NOTE for backwards compat, deprecate hf_hub use

@ -17,6 +17,9 @@ import torch
import torch.nn as nn import torch.nn as nn
__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']
class FeatureInfo: class FeatureInfo:
def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]): def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):

@ -35,6 +35,10 @@ except ImportError:
pass pass
__all__ = ['register_notrace_module', 'register_notrace_function', 'create_feature_extractor',
'FeatureGraphNet', 'GraphExtractNet']
def register_notrace_module(module: Type[nn.Module]): def register_notrace_module(module: Type[nn.Module]):
""" """
Any module not under timm.models.layers should get this decorator if we don't want to trace through it. Any module not under timm.models.layers should get this decorator if we don't want to trace through it.

@ -12,6 +12,8 @@ import timm.models._builder
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
__all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_checkpoint', 'resume_checkpoint']
def clean_state_dict(state_dict): def clean_state_dict(state_dict):
# 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training # 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training

@ -31,6 +31,9 @@ except ImportError:
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
__all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf',
'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub']
def get_cache_dir(child_dir=''): def get_cache_dir(child_dir=''):
""" """

@ -9,6 +9,9 @@ import torch
from torch import nn as nn from torch import nn as nn
from torch.utils.checkpoint import checkpoint from torch.utils.checkpoint import checkpoint
__all__ = ['model_parameters', 'named_apply', 'named_modules', 'named_modules_with_params', 'adapt_input_conv',
'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq']
def model_parameters(model, exclude_head=False): def model_parameters(model, exclude_head=False):
if exclude_head: if exclude_head:

@ -4,6 +4,9 @@ from dataclasses import dataclass, field, replace, asdict
from typing import Any, Deque, Dict, Tuple, Optional, Union from typing import Any, Deque, Dict, Tuple, Optional, Union
__all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg', 'split_model_name_tag', 'generate_default_cfgs']
@dataclass @dataclass
class PretrainedCfg: class PretrainedCfg:
""" """

@ -5,6 +5,8 @@ from torch import nn as nn
from timm.layers import Conv2dSame, BatchNormAct2d, Linear from timm.layers import Conv2dSame, BatchNormAct2d, Linear
__all__ = ['extract_layer', 'set_layer', 'adapt_model_from_string', 'adapt_model_from_file']
def extract_layer(model, layer): def extract_layer(model, layer):
layer = layer.split('.') layer = layer.split('.')

@ -12,7 +12,7 @@ from typing import List, Optional, Union, Tuple
from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
__all__ = [ __all__ = [
'list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules', 'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name'] 'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name']
_module_to_models = defaultdict(set) # dict of sets to check membership of model in module _module_to_models = defaultdict(set) # dict of sets to check membership of model in module

@ -0,0 +1,4 @@
from ._factory import *
import warnings
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)

@ -0,0 +1,4 @@
from ._features import *
import warnings
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)

@ -0,0 +1,4 @@
from ._features_fx import *
import warnings
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)

@ -0,0 +1,7 @@
from ._builder import *
from ._helpers import *
from ._manipulate import *
from ._prune import *
import warnings
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)

@ -0,0 +1,4 @@
from _hub import *
import warnings
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)

@ -43,3 +43,6 @@ from timm.layers.std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, Scal
from timm.layers.test_time_pool import TestTimePoolHead, apply_test_time_pool from timm.layers.test_time_pool import TestTimePoolHead, apply_test_time_pool
from timm.layers.trace_utils import _assert, _float_to_int from timm.layers.trace_utils import _assert, _float_to_int
from timm.layers.weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_ from timm.layers.weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_
import warnings
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)

@ -0,0 +1,4 @@
from ._registry import *
import warnings
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
Loading…
Cancel
Save