From cda39b35bd7ac4a8053f422802ba65f88dbb6e3c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 9 Dec 2022 14:39:45 -0800 Subject: [PATCH] Add a deprecation phase to module re-org --- benchmark.py | 3 ++- timm/models/_builder.py | 4 ++++ timm/models/_factory.py | 3 +++ timm/models/_features.py | 3 +++ timm/models/_features_fx.py | 4 ++++ timm/models/_helpers.py | 2 ++ timm/models/_hub.py | 3 +++ timm/models/_manipulate.py | 3 +++ timm/models/_pretrained.py | 3 +++ timm/models/_prune.py | 2 ++ timm/models/_registry.py | 2 +- timm/models/factory.py | 4 ++++ timm/models/features.py | 4 ++++ timm/models/fx_features.py | 4 ++++ timm/models/helpers.py | 7 +++++++ timm/models/hub.py | 4 ++++ timm/models/layers/__init__.py | 3 +++ timm/models/registry.py | 4 ++++ 18 files changed, 60 insertions(+), 2 deletions(-) create mode 100644 timm/models/factory.py create mode 100644 timm/models/features.py create mode 100644 timm/models/fx_features.py create mode 100644 timm/models/helpers.py create mode 100644 timm/models/hub.py create mode 100644 timm/models/registry.py diff --git a/benchmark.py b/benchmark.py index 04557a7d..95e2cb5a 100755 --- a/benchmark.py +++ b/benchmark.py @@ -19,7 +19,8 @@ import torch.nn as nn import torch.nn.parallel from timm.data import resolve_data_config -from timm.models import create_model, is_model, list_models, set_fast_norm +from timm.layers import set_fast_norm +from timm.models import create_model, is_model, list_models from timm.optim import create_optimizer_v2 from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry diff --git a/timm/models/_builder.py b/timm/models/_builder.py index c99c85f6..f634650e 100644 --- a/timm/models/_builder.py +++ b/timm/models/_builder.py @@ -23,6 +23,10 @@ _DOWNLOAD_PROGRESS = False _CHECK_HASH = False +__all__ = ['set_pretrained_download_progress', 'set_pretrained_check_hash', 'load_custom_pretrained', 'load_pretrained', + 'pretrained_cfg_for_features', 'resolve_pretrained_cfg', 'build_model_with_cfg'] + + def _resolve_pretrained_source(pretrained_cfg): cfg_source = pretrained_cfg.get('source', '') pretrained_url = pretrained_cfg.get('url', None) diff --git a/timm/models/_factory.py b/timm/models/_factory.py index 2b050ad6..a8092419 100644 --- a/timm/models/_factory.py +++ b/timm/models/_factory.py @@ -9,6 +9,9 @@ from ._hub import load_model_config_from_hf from ._registry import is_model, model_entrypoint +__all__ = ['parse_model_name', 'safe_model_name', 'create_model'] + + def parse_model_name(model_name): if model_name.startswith('hf_hub'): # NOTE for backwards compat, deprecate hf_hub use diff --git a/timm/models/_features.py b/timm/models/_features.py index 0bc46419..59b080cd 100644 --- a/timm/models/_features.py +++ b/timm/models/_features.py @@ -17,6 +17,9 @@ import torch import torch.nn as nn +__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet'] + + class FeatureInfo: def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]): diff --git a/timm/models/_features_fx.py b/timm/models/_features_fx.py index 2d4a33c2..10670a1d 100644 --- a/timm/models/_features_fx.py +++ b/timm/models/_features_fx.py @@ -35,6 +35,10 @@ except ImportError: pass +__all__ = ['register_notrace_module', 'register_notrace_function', 'create_feature_extractor', + 'FeatureGraphNet', 'GraphExtractNet'] + + def register_notrace_module(module: Type[nn.Module]): """ Any module not under timm.models.layers should get this decorator if we don't want to trace through it. diff --git a/timm/models/_helpers.py b/timm/models/_helpers.py index 2856842d..995292aa 100644 --- a/timm/models/_helpers.py +++ b/timm/models/_helpers.py @@ -12,6 +12,8 @@ import timm.models._builder _logger = logging.getLogger(__name__) +__all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_checkpoint', 'resume_checkpoint'] + def clean_state_dict(state_dict): # 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training diff --git a/timm/models/_hub.py b/timm/models/_hub.py index 2a87ae7e..e6b7d558 100644 --- a/timm/models/_hub.py +++ b/timm/models/_hub.py @@ -31,6 +31,9 @@ except ImportError: _logger = logging.getLogger(__name__) +__all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf', + 'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub'] + def get_cache_dir(child_dir=''): """ diff --git a/timm/models/_manipulate.py b/timm/models/_manipulate.py index 82a922a2..192979fc 100644 --- a/timm/models/_manipulate.py +++ b/timm/models/_manipulate.py @@ -9,6 +9,9 @@ import torch from torch import nn as nn from torch.utils.checkpoint import checkpoint +__all__ = ['model_parameters', 'named_apply', 'named_modules', 'named_modules_with_params', 'adapt_input_conv', + 'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq'] + def model_parameters(model, exclude_head=False): if exclude_head: diff --git a/timm/models/_pretrained.py b/timm/models/_pretrained.py index 60f38fd4..c422dab7 100644 --- a/timm/models/_pretrained.py +++ b/timm/models/_pretrained.py @@ -4,6 +4,9 @@ from dataclasses import dataclass, field, replace, asdict from typing import Any, Deque, Dict, Tuple, Optional, Union +__all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg', 'split_model_name_tag', 'generate_default_cfgs'] + + @dataclass class PretrainedCfg: """ diff --git a/timm/models/_prune.py b/timm/models/_prune.py index 0d744e40..4e744dec 100644 --- a/timm/models/_prune.py +++ b/timm/models/_prune.py @@ -5,6 +5,8 @@ from torch import nn as nn from timm.layers import Conv2dSame, BatchNormAct2d, Linear +__all__ = ['extract_layer', 'set_layer', 'adapt_model_from_string', 'adapt_model_from_file'] + def extract_layer(model, layer): layer = layer.split('.') diff --git a/timm/models/_registry.py b/timm/models/_registry.py index 97c8fd59..fc7b3437 100644 --- a/timm/models/_registry.py +++ b/timm/models/_registry.py @@ -12,7 +12,7 @@ from typing import List, Optional, Union, Tuple from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag __all__ = [ - 'list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules', + 'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules', 'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name'] _module_to_models = defaultdict(set) # dict of sets to check membership of model in module diff --git a/timm/models/factory.py b/timm/models/factory.py new file mode 100644 index 00000000..0ae83dc0 --- /dev/null +++ b/timm/models/factory.py @@ -0,0 +1,4 @@ +from ._factory import * + +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning) diff --git a/timm/models/features.py b/timm/models/features.py new file mode 100644 index 00000000..25605d99 --- /dev/null +++ b/timm/models/features.py @@ -0,0 +1,4 @@ +from ._features import * + +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning) diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py new file mode 100644 index 00000000..0ff3a18b --- /dev/null +++ b/timm/models/fx_features.py @@ -0,0 +1,4 @@ +from ._features_fx import * + +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning) diff --git a/timm/models/helpers.py b/timm/models/helpers.py new file mode 100644 index 00000000..6bc82eb8 --- /dev/null +++ b/timm/models/helpers.py @@ -0,0 +1,7 @@ +from ._builder import * +from ._helpers import * +from ._manipulate import * +from ._prune import * + +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning) diff --git a/timm/models/hub.py b/timm/models/hub.py new file mode 100644 index 00000000..074ca025 --- /dev/null +++ b/timm/models/hub.py @@ -0,0 +1,4 @@ +from _hub import * + +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning) diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 1bfb95d5..97e70563 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -43,3 +43,6 @@ from timm.layers.std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, Scal from timm.layers.test_time_pool import TestTimePoolHead, apply_test_time_pool from timm.layers.trace_utils import _assert, _float_to_int from timm.layers.weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_ + +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning) diff --git a/timm/models/registry.py b/timm/models/registry.py new file mode 100644 index 00000000..58e2e1f4 --- /dev/null +++ b/timm/models/registry.py @@ -0,0 +1,4 @@ +from ._registry import * + +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)