From 927f031293a30afb940fff0bee34b85d9c059b0e Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 6 Dec 2022 15:00:06 -0800 Subject: [PATCH 1/2] Major module / path restructure, timm.models.layers -> timm.layers, add _ prefix to all non model modules in timm.models --- avg_checkpoints.py | 2 +- clean_checkpoint.py | 2 +- hubconf.py | 5 +- inference.py | 9 +- tests/test_layers.py | 5 +- tests/test_models.py | 2 +- timm/__init__.py | 2 +- timm/data/readers/class_map.py | 3 +- timm/layers/__init__.py | 44 + timm/{models => }/layers/activations.py | 0 timm/{models => }/layers/activations_jit.py | 0 timm/{models => }/layers/activations_me.py | 0 .../layers/adaptive_avgmax_pool.py | 0 timm/{models => }/layers/attention_pool2d.py | 0 timm/{models => }/layers/blur_pool.py | 0 timm/{models => }/layers/bottleneck_attn.py | 0 timm/{models => }/layers/cbam.py | 0 timm/{models => }/layers/classifier.py | 0 timm/{models => }/layers/cond_conv2d.py | 0 timm/{models => }/layers/config.py | 0 timm/{models => }/layers/conv2d_same.py | 0 timm/{models => }/layers/conv_bn_act.py | 0 timm/{models => }/layers/create_act.py | 0 timm/{models => }/layers/create_attn.py | 0 timm/{models => }/layers/create_conv2d.py | 0 timm/{models => }/layers/create_norm.py | 0 timm/{models => }/layers/create_norm_act.py | 0 timm/{models => }/layers/drop.py | 0 timm/{models => }/layers/eca.py | 0 timm/{models => }/layers/evo_norm.py | 0 timm/{models => }/layers/fast_norm.py | 0 .../layers/filter_response_norm.py | 0 timm/{models => }/layers/gather_excite.py | 0 timm/{models => }/layers/global_context.py | 0 timm/{models => }/layers/halo_attn.py | 0 timm/{models => }/layers/helpers.py | 0 timm/{models => }/layers/inplace_abn.py | 0 timm/{models => }/layers/lambda_layer.py | 0 timm/{models => }/layers/linear.py | 0 timm/{models => }/layers/median_pool.py | 0 timm/{models => }/layers/mixed_conv2d.py | 0 timm/{models => }/layers/ml_decoder.py | 0 timm/{models => }/layers/mlp.py | 0 timm/{models => }/layers/non_local_attn.py | 0 timm/{models => }/layers/norm.py | 0 timm/{models => }/layers/norm_act.py | 0 timm/{models => }/layers/padding.py | 0 timm/{models => }/layers/patch_embed.py | 0 timm/{models => }/layers/pool2d_same.py | 0 timm/{models => }/layers/pos_embed.py | 0 timm/{models => }/layers/selective_kernel.py | 0 timm/{models => }/layers/separable_conv.py | 0 timm/{models => }/layers/space_to_depth.py | 0 timm/{models => }/layers/split_attn.py | 0 timm/{models => }/layers/split_batchnorm.py | 0 timm/{models => }/layers/squeeze_excite.py | 0 timm/{models => }/layers/std_conv.py | 0 timm/{models => }/layers/test_time_pool.py | 0 timm/{models => }/layers/trace_utils.py | 0 timm/{models => }/layers/weight_init.py | 0 timm/models/__init__.py | 22 +- timm/models/_builder.py | 395 ++++++++ ...tnet_blocks.py => _efficientnet_blocks.py} | 3 +- ...et_builder.py => _efficientnet_builder.py} | 4 +- timm/models/{factory.py => _factory.py} | 10 +- timm/models/{features.py => _features.py} | 0 .../{fx_features.py => _features_fx.py} | 10 +- timm/models/_helpers.py | 113 +++ timm/models/{hub.py => _hub.py} | 2 +- timm/models/_manipulate.py | 255 ++++++ timm/models/{pretrained.py => _pretrained.py} | 0 timm/models/_prune.py | 111 +++ .../ecaresnet101d_pruned.txt | 0 .../ecaresnet50d_pruned.txt | 0 .../efficientnet_b1_pruned.txt | 0 .../efficientnet_b2_pruned.txt | 0 .../efficientnet_b3_pruned.txt | 0 timm/models/{registry.py => _registry.py} | 6 +- timm/models/beit.py | 9 +- timm/models/byoanet.py | 4 +- timm/models/byobnet.py | 12 +- timm/models/cait.py | 9 +- timm/models/coat.py | 19 +- timm/models/convit.py | 16 +- timm/models/convmixer.py | 9 +- timm/models/convnext.py | 10 +- timm/models/crossvit.py | 19 +- timm/models/cspnet.py | 14 +- timm/models/deit.py | 6 +- timm/models/densenet.py | 8 +- timm/models/dla.py | 6 +- timm/models/dpn.py | 6 +- timm/models/edgenext.py | 14 +- timm/models/efficientformer.py | 8 +- timm/models/efficientnet.py | 14 +- timm/models/gcvit.py | 11 +- timm/models/ghostnet.py | 11 +- timm/models/gluon_resnet.py | 8 +- timm/models/gluon_xception.py | 6 +- timm/models/hardcorenas.py | 12 +- timm/models/helpers.py | 855 ------------------ timm/models/hrnet.py | 10 +- timm/models/inception_resnet_v2.py | 7 +- timm/models/inception_v3.py | 10 +- timm/models/inception_v4.py | 6 +- timm/models/layers/__init__.py | 83 +- timm/models/levit.py | 12 +- timm/models/maxxvit.py | 18 +- timm/models/mlp_mixer.py | 10 +- timm/models/mobilenetv3.py | 13 +- timm/models/mobilevit.py | 12 +- timm/models/mvitv2.py | 10 +- timm/models/nasnet.py | 6 +- timm/models/nest.py | 14 +- timm/models/nfnet.py | 16 +- timm/models/pit.py | 10 +- timm/models/pnasnet.py | 6 +- timm/models/poolformer.py | 10 +- timm/models/pvt_v2.py | 6 +- timm/models/regnet.py | 11 +- timm/models/res2net.py | 4 +- timm/models/resnest.py | 7 +- timm/models/resnet.py | 15 +- timm/models/resnetv2.py | 11 +- timm/models/rexnet.py | 16 +- timm/models/selecsls.py | 6 +- timm/models/senet.py | 6 +- timm/models/sequencer.py | 10 +- timm/models/sknet.py | 6 +- timm/models/swin_transformer.py | 11 +- timm/models/swin_transformer_v2.py | 10 +- timm/models/swin_transformer_v2_cr.py | 11 +- timm/models/tnt.py | 13 +- timm/models/tresnet.py | 8 +- timm/models/twins.py | 15 +- timm/models/vgg.py | 18 +- timm/models/visformer.py | 10 +- timm/models/vision_transformer.py | 19 +- timm/models/vision_transformer_hybrid.py | 9 +- timm/models/vision_transformer_relpos.py | 14 +- timm/models/volo.py | 10 +- timm/models/vovnet.py | 10 +- timm/models/xception.py | 6 +- timm/models/xception_aligned.py | 9 +- timm/models/xcit.py | 12 +- timm/optim/optim_factory.py | 2 +- timm/version.py | 2 +- train.py | 9 +- validate.py | 16 +- 149 files changed, 1387 insertions(+), 1269 deletions(-) create mode 100644 timm/layers/__init__.py rename timm/{models => }/layers/activations.py (100%) rename timm/{models => }/layers/activations_jit.py (100%) rename timm/{models => }/layers/activations_me.py (100%) rename timm/{models => }/layers/adaptive_avgmax_pool.py (100%) rename timm/{models => }/layers/attention_pool2d.py (100%) rename timm/{models => }/layers/blur_pool.py (100%) rename timm/{models => }/layers/bottleneck_attn.py (100%) rename timm/{models => }/layers/cbam.py (100%) rename timm/{models => }/layers/classifier.py (100%) rename timm/{models => }/layers/cond_conv2d.py (100%) rename timm/{models => }/layers/config.py (100%) rename timm/{models => }/layers/conv2d_same.py (100%) rename timm/{models => }/layers/conv_bn_act.py (100%) rename timm/{models => }/layers/create_act.py (100%) rename timm/{models => }/layers/create_attn.py (100%) rename timm/{models => }/layers/create_conv2d.py (100%) rename timm/{models => }/layers/create_norm.py (100%) rename timm/{models => }/layers/create_norm_act.py (100%) rename timm/{models => }/layers/drop.py (100%) rename timm/{models => }/layers/eca.py (100%) rename timm/{models => }/layers/evo_norm.py (100%) rename timm/{models => }/layers/fast_norm.py (100%) rename timm/{models => }/layers/filter_response_norm.py (100%) rename timm/{models => }/layers/gather_excite.py (100%) rename timm/{models => }/layers/global_context.py (100%) rename timm/{models => }/layers/halo_attn.py (100%) rename timm/{models => }/layers/helpers.py (100%) rename timm/{models => }/layers/inplace_abn.py (100%) rename timm/{models => }/layers/lambda_layer.py (100%) rename timm/{models => }/layers/linear.py (100%) rename timm/{models => }/layers/median_pool.py (100%) rename timm/{models => }/layers/mixed_conv2d.py (100%) rename timm/{models => }/layers/ml_decoder.py (100%) rename timm/{models => }/layers/mlp.py (100%) rename timm/{models => }/layers/non_local_attn.py (100%) rename timm/{models => }/layers/norm.py (100%) rename timm/{models => }/layers/norm_act.py (100%) rename timm/{models => }/layers/padding.py (100%) rename timm/{models => }/layers/patch_embed.py (100%) rename timm/{models => }/layers/pool2d_same.py (100%) rename timm/{models => }/layers/pos_embed.py (100%) rename timm/{models => }/layers/selective_kernel.py (100%) rename timm/{models => }/layers/separable_conv.py (100%) rename timm/{models => }/layers/space_to_depth.py (100%) rename timm/{models => }/layers/split_attn.py (100%) rename timm/{models => }/layers/split_batchnorm.py (100%) rename timm/{models => }/layers/squeeze_excite.py (100%) rename timm/{models => }/layers/std_conv.py (100%) rename timm/{models => }/layers/test_time_pool.py (100%) rename timm/{models => }/layers/trace_utils.py (100%) rename timm/{models => }/layers/weight_init.py (100%) create mode 100644 timm/models/_builder.py rename timm/models/{efficientnet_blocks.py => _efficientnet_blocks.py} (99%) rename timm/models/{efficientnet_builder.py => _efficientnet_builder.py} (99%) rename timm/models/{factory.py => _factory.py} (94%) rename timm/models/{features.py => _features.py} (100%) rename timm/models/{fx_features.py => _features_fx.py} (93%) create mode 100644 timm/models/_helpers.py rename timm/models/{hub.py => _hub.py} (99%) create mode 100644 timm/models/_manipulate.py rename timm/models/{pretrained.py => _pretrained.py} (100%) create mode 100644 timm/models/_prune.py rename timm/models/{pruned => _pruned}/ecaresnet101d_pruned.txt (100%) rename timm/models/{pruned => _pruned}/ecaresnet50d_pruned.txt (100%) rename timm/models/{pruned => _pruned}/efficientnet_b1_pruned.txt (100%) rename timm/models/{pruned => _pruned}/efficientnet_b2_pruned.txt (100%) rename timm/models/{pruned => _pruned}/efficientnet_b3_pruned.txt (100%) rename timm/models/{registry.py => _registry.py} (95%) delete mode 100644 timm/models/helpers.py diff --git a/avg_checkpoints.py b/avg_checkpoints.py index ea8bbe84..83af5bbd 100755 --- a/avg_checkpoints.py +++ b/avg_checkpoints.py @@ -16,7 +16,7 @@ import argparse import os import glob import hashlib -from timm.models.helpers import load_state_dict +from timm.models import load_state_dict parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager') parser.add_argument('--input', default='', type=str, metavar='PATH', diff --git a/clean_checkpoint.py b/clean_checkpoint.py index 8ec892b2..17c270db 100755 --- a/clean_checkpoint.py +++ b/clean_checkpoint.py @@ -13,7 +13,7 @@ import os import hashlib import shutil from collections import OrderedDict -from timm.models.helpers import load_state_dict +from timm.models import load_state_dict parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner') parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', diff --git a/hubconf.py b/hubconf.py index 70fed79a..6b2061ea 100644 --- a/hubconf.py +++ b/hubconf.py @@ -1,4 +1,3 @@ dependencies = ['torch'] -from timm.models import registry - -globals().update(registry._model_entrypoints) +import timm +globals().update(timm.models._registry._model_entrypoints) diff --git a/inference.py b/inference.py index bc794840..1509b323 100755 --- a/inference.py +++ b/inference.py @@ -5,11 +5,11 @@ An example inference script that outputs top-k class ids for images in a folder Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman) """ -import os -import time import argparse import json import logging +import os +import time from contextlib import suppress from functools import partial @@ -17,12 +17,11 @@ import numpy as np import pandas as pd import torch -from timm.models import create_model, apply_test_time_pool, load_checkpoint from timm.data import create_dataset, create_loader, resolve_data_config +from timm.layers import apply_test_time_pool +from timm.models import create_model from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser - - try: from apex import amp has_apex = True diff --git a/tests/test_layers.py b/tests/test_layers.py index 508a6aae..da061870 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -1,10 +1,7 @@ -import pytest import torch import torch.nn as nn -import platform -import os -from timm.models.layers import create_act_layer, get_act_layer, set_layer_config +from timm.layers import create_act_layer, set_layer_config class MLP(nn.Module): diff --git a/tests/test_models.py b/tests/test_models.py index dd1330eb..d6c0052f 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -14,7 +14,7 @@ except ImportError: import timm from timm import list_models, create_model, set_scriptable, get_pretrained_cfg_value -from timm.models.fx_features import _leaf_modules, _autowrap_functions +from timm.models._features_fx import _leaf_modules, _autowrap_functions if hasattr(torch._C, '_jit_set_profiling_executor'): # legacy executor is too slow to compile large models for unit tests diff --git a/timm/__init__.py b/timm/__init__.py index faf34dbc..3d38cdb9 100644 --- a/timm/__init__.py +++ b/timm/__init__.py @@ -1,4 +1,4 @@ from .version import __version__ +from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable from .models import create_model, list_models, list_pretrained, is_model, list_modules, model_entrypoint, \ - is_scriptable, is_exportable, set_scriptable, set_exportable, \ is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value diff --git a/timm/data/readers/class_map.py b/timm/data/readers/class_map.py index 6cf3f57e..885be6e2 100644 --- a/timm/data/readers/class_map.py +++ b/timm/data/readers/class_map.py @@ -1,6 +1,7 @@ import os import pickle + def load_class_map(map_or_filename, root=''): if isinstance(map_or_filename, dict): assert dict, 'class_map dict must be non-empty' @@ -14,7 +15,7 @@ def load_class_map(map_or_filename, root=''): with open(class_map_path) as f: class_to_idx = {v.strip(): k for k, v in enumerate(f)} elif class_map_ext == '.pkl': - with open(class_map_path,'rb') as f: + with open(class_map_path, 'rb') as f: class_to_idx = pickle.load(f) else: assert False, f'Unsupported class map file extension ({class_map_ext}).' diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py new file mode 100644 index 00000000..21c641b6 --- /dev/null +++ b/timm/layers/__init__.py @@ -0,0 +1,44 @@ +from .activations import * +from .adaptive_avgmax_pool import \ + adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d +from .blur_pool import BlurPool2d +from .classifier import ClassifierHead, create_classifier +from .cond_conv2d import CondConv2d, get_condconv_initializer +from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ + set_layer_config +from .conv2d_same import Conv2dSame, conv2d_same +from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct +from .create_act import create_act_layer, get_act_layer, get_act_fn +from .create_attn import get_attn, create_attn +from .create_conv2d import create_conv2d +from .create_norm import get_norm_layer, create_norm_layer +from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer +from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path +from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn +from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\ + EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a +from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm +from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d +from .gather_excite import GatherExcite +from .global_context import GlobalContext +from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple +from .inplace_abn import InplaceAbn +from .linear import Linear +from .mixed_conv2d import MixedConv2d +from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp +from .non_local_attn import NonLocalAttn, BatNonLocalAttn +from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d +from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm +from .padding import get_padding, get_same_padding, pad_same +from .patch_embed import PatchEmbed +from .pool2d_same import AvgPool2dSame, create_pool2d +from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite +from .selective_kernel import SelectiveKernel +from .separable_conv import SeparableConv2d, SeparableConvNormAct +from .space_to_depth import SpaceToDepthModule +from .split_attn import SplitAttn +from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model +from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame +from .test_time_pool import TestTimePoolHead, apply_test_time_pool +from .trace_utils import _assert, _float_to_int +from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_ diff --git a/timm/models/layers/activations.py b/timm/layers/activations.py similarity index 100% rename from timm/models/layers/activations.py rename to timm/layers/activations.py diff --git a/timm/models/layers/activations_jit.py b/timm/layers/activations_jit.py similarity index 100% rename from timm/models/layers/activations_jit.py rename to timm/layers/activations_jit.py diff --git a/timm/models/layers/activations_me.py b/timm/layers/activations_me.py similarity index 100% rename from timm/models/layers/activations_me.py rename to timm/layers/activations_me.py diff --git a/timm/models/layers/adaptive_avgmax_pool.py b/timm/layers/adaptive_avgmax_pool.py similarity index 100% rename from timm/models/layers/adaptive_avgmax_pool.py rename to timm/layers/adaptive_avgmax_pool.py diff --git a/timm/models/layers/attention_pool2d.py b/timm/layers/attention_pool2d.py similarity index 100% rename from timm/models/layers/attention_pool2d.py rename to timm/layers/attention_pool2d.py diff --git a/timm/models/layers/blur_pool.py b/timm/layers/blur_pool.py similarity index 100% rename from timm/models/layers/blur_pool.py rename to timm/layers/blur_pool.py diff --git a/timm/models/layers/bottleneck_attn.py b/timm/layers/bottleneck_attn.py similarity index 100% rename from timm/models/layers/bottleneck_attn.py rename to timm/layers/bottleneck_attn.py diff --git a/timm/models/layers/cbam.py b/timm/layers/cbam.py similarity index 100% rename from timm/models/layers/cbam.py rename to timm/layers/cbam.py diff --git a/timm/models/layers/classifier.py b/timm/layers/classifier.py similarity index 100% rename from timm/models/layers/classifier.py rename to timm/layers/classifier.py diff --git a/timm/models/layers/cond_conv2d.py b/timm/layers/cond_conv2d.py similarity index 100% rename from timm/models/layers/cond_conv2d.py rename to timm/layers/cond_conv2d.py diff --git a/timm/models/layers/config.py b/timm/layers/config.py similarity index 100% rename from timm/models/layers/config.py rename to timm/layers/config.py diff --git a/timm/models/layers/conv2d_same.py b/timm/layers/conv2d_same.py similarity index 100% rename from timm/models/layers/conv2d_same.py rename to timm/layers/conv2d_same.py diff --git a/timm/models/layers/conv_bn_act.py b/timm/layers/conv_bn_act.py similarity index 100% rename from timm/models/layers/conv_bn_act.py rename to timm/layers/conv_bn_act.py diff --git a/timm/models/layers/create_act.py b/timm/layers/create_act.py similarity index 100% rename from timm/models/layers/create_act.py rename to timm/layers/create_act.py diff --git a/timm/models/layers/create_attn.py b/timm/layers/create_attn.py similarity index 100% rename from timm/models/layers/create_attn.py rename to timm/layers/create_attn.py diff --git a/timm/models/layers/create_conv2d.py b/timm/layers/create_conv2d.py similarity index 100% rename from timm/models/layers/create_conv2d.py rename to timm/layers/create_conv2d.py diff --git a/timm/models/layers/create_norm.py b/timm/layers/create_norm.py similarity index 100% rename from timm/models/layers/create_norm.py rename to timm/layers/create_norm.py diff --git a/timm/models/layers/create_norm_act.py b/timm/layers/create_norm_act.py similarity index 100% rename from timm/models/layers/create_norm_act.py rename to timm/layers/create_norm_act.py diff --git a/timm/models/layers/drop.py b/timm/layers/drop.py similarity index 100% rename from timm/models/layers/drop.py rename to timm/layers/drop.py diff --git a/timm/models/layers/eca.py b/timm/layers/eca.py similarity index 100% rename from timm/models/layers/eca.py rename to timm/layers/eca.py diff --git a/timm/models/layers/evo_norm.py b/timm/layers/evo_norm.py similarity index 100% rename from timm/models/layers/evo_norm.py rename to timm/layers/evo_norm.py diff --git a/timm/models/layers/fast_norm.py b/timm/layers/fast_norm.py similarity index 100% rename from timm/models/layers/fast_norm.py rename to timm/layers/fast_norm.py diff --git a/timm/models/layers/filter_response_norm.py b/timm/layers/filter_response_norm.py similarity index 100% rename from timm/models/layers/filter_response_norm.py rename to timm/layers/filter_response_norm.py diff --git a/timm/models/layers/gather_excite.py b/timm/layers/gather_excite.py similarity index 100% rename from timm/models/layers/gather_excite.py rename to timm/layers/gather_excite.py diff --git a/timm/models/layers/global_context.py b/timm/layers/global_context.py similarity index 100% rename from timm/models/layers/global_context.py rename to timm/layers/global_context.py diff --git a/timm/models/layers/halo_attn.py b/timm/layers/halo_attn.py similarity index 100% rename from timm/models/layers/halo_attn.py rename to timm/layers/halo_attn.py diff --git a/timm/models/layers/helpers.py b/timm/layers/helpers.py similarity index 100% rename from timm/models/layers/helpers.py rename to timm/layers/helpers.py diff --git a/timm/models/layers/inplace_abn.py b/timm/layers/inplace_abn.py similarity index 100% rename from timm/models/layers/inplace_abn.py rename to timm/layers/inplace_abn.py diff --git a/timm/models/layers/lambda_layer.py b/timm/layers/lambda_layer.py similarity index 100% rename from timm/models/layers/lambda_layer.py rename to timm/layers/lambda_layer.py diff --git a/timm/models/layers/linear.py b/timm/layers/linear.py similarity index 100% rename from timm/models/layers/linear.py rename to timm/layers/linear.py diff --git a/timm/models/layers/median_pool.py b/timm/layers/median_pool.py similarity index 100% rename from timm/models/layers/median_pool.py rename to timm/layers/median_pool.py diff --git a/timm/models/layers/mixed_conv2d.py b/timm/layers/mixed_conv2d.py similarity index 100% rename from timm/models/layers/mixed_conv2d.py rename to timm/layers/mixed_conv2d.py diff --git a/timm/models/layers/ml_decoder.py b/timm/layers/ml_decoder.py similarity index 100% rename from timm/models/layers/ml_decoder.py rename to timm/layers/ml_decoder.py diff --git a/timm/models/layers/mlp.py b/timm/layers/mlp.py similarity index 100% rename from timm/models/layers/mlp.py rename to timm/layers/mlp.py diff --git a/timm/models/layers/non_local_attn.py b/timm/layers/non_local_attn.py similarity index 100% rename from timm/models/layers/non_local_attn.py rename to timm/layers/non_local_attn.py diff --git a/timm/models/layers/norm.py b/timm/layers/norm.py similarity index 100% rename from timm/models/layers/norm.py rename to timm/layers/norm.py diff --git a/timm/models/layers/norm_act.py b/timm/layers/norm_act.py similarity index 100% rename from timm/models/layers/norm_act.py rename to timm/layers/norm_act.py diff --git a/timm/models/layers/padding.py b/timm/layers/padding.py similarity index 100% rename from timm/models/layers/padding.py rename to timm/layers/padding.py diff --git a/timm/models/layers/patch_embed.py b/timm/layers/patch_embed.py similarity index 100% rename from timm/models/layers/patch_embed.py rename to timm/layers/patch_embed.py diff --git a/timm/models/layers/pool2d_same.py b/timm/layers/pool2d_same.py similarity index 100% rename from timm/models/layers/pool2d_same.py rename to timm/layers/pool2d_same.py diff --git a/timm/models/layers/pos_embed.py b/timm/layers/pos_embed.py similarity index 100% rename from timm/models/layers/pos_embed.py rename to timm/layers/pos_embed.py diff --git a/timm/models/layers/selective_kernel.py b/timm/layers/selective_kernel.py similarity index 100% rename from timm/models/layers/selective_kernel.py rename to timm/layers/selective_kernel.py diff --git a/timm/models/layers/separable_conv.py b/timm/layers/separable_conv.py similarity index 100% rename from timm/models/layers/separable_conv.py rename to timm/layers/separable_conv.py diff --git a/timm/models/layers/space_to_depth.py b/timm/layers/space_to_depth.py similarity index 100% rename from timm/models/layers/space_to_depth.py rename to timm/layers/space_to_depth.py diff --git a/timm/models/layers/split_attn.py b/timm/layers/split_attn.py similarity index 100% rename from timm/models/layers/split_attn.py rename to timm/layers/split_attn.py diff --git a/timm/models/layers/split_batchnorm.py b/timm/layers/split_batchnorm.py similarity index 100% rename from timm/models/layers/split_batchnorm.py rename to timm/layers/split_batchnorm.py diff --git a/timm/models/layers/squeeze_excite.py b/timm/layers/squeeze_excite.py similarity index 100% rename from timm/models/layers/squeeze_excite.py rename to timm/layers/squeeze_excite.py diff --git a/timm/models/layers/std_conv.py b/timm/layers/std_conv.py similarity index 100% rename from timm/models/layers/std_conv.py rename to timm/layers/std_conv.py diff --git a/timm/models/layers/test_time_pool.py b/timm/layers/test_time_pool.py similarity index 100% rename from timm/models/layers/test_time_pool.py rename to timm/layers/test_time_pool.py diff --git a/timm/models/layers/trace_utils.py b/timm/layers/trace_utils.py similarity index 100% rename from timm/models/layers/trace_utils.py rename to timm/layers/trace_utils.py diff --git a/timm/models/layers/weight_init.py b/timm/layers/weight_init.py similarity index 100% rename from timm/models/layers/weight_init.py rename to timm/layers/weight_init.py diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 301186dd..5ecc8915 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -64,12 +64,18 @@ from .xception import * from .xception_aligned import * from .xcit import * -from .factory import create_model, parse_model_name, safe_model_name -from .helpers import load_checkpoint, resume_checkpoint, model_parameters -from .layers import TestTimePoolHead, apply_test_time_pool -from .layers import convert_splitbn_model, convert_sync_batchnorm -from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit -from .layers import set_fast_norm -from .pretrained import PretrainedCfg, filter_pretrained_cfg, generate_default_cfgs, split_model_name_tag -from .registry import register_model, model_entrypoint, list_models, list_pretrained, is_model, list_modules,\ +from ._builder import build_model_with_cfg, load_pretrained, load_custom_pretrained, resolve_pretrained_cfg, \ + set_pretrained_download_progress, set_pretrained_check_hash +from ._factory import create_model, parse_model_name, safe_model_name +from ._features import FeatureInfo, FeatureHooks, FeatureHookNet, FeatureListNet, FeatureDictNet +from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, \ + register_notrace_module, register_notrace_function +from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_checkpoint, resume_checkpoint +from ._hub import load_model_config_from_hf, load_state_dict_from_hf, push_to_hf_hub +from ._manipulate import model_parameters, named_apply, named_modules, named_modules_with_params, \ + group_modules, group_parameters, checkpoint_seq, adapt_input_conv +from ._pretrained import PretrainedCfg, DefaultCfg, \ + filter_pretrained_cfg, generate_default_cfgs, split_model_name_tag +from ._prune import adapt_model_from_string +from ._registry import register_model, model_entrypoint, list_models, list_pretrained, is_model, list_modules, \ is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value diff --git a/timm/models/_builder.py b/timm/models/_builder.py new file mode 100644 index 00000000..c99c85f6 --- /dev/null +++ b/timm/models/_builder.py @@ -0,0 +1,395 @@ +import dataclasses +import logging +from copy import deepcopy +from typing import Optional, Dict, Callable, Any, Tuple + +from torch import nn as nn +from torch.hub import load_state_dict_from_url + +from timm.models._features import FeatureListNet, FeatureHookNet +from timm.models._features_fx import FeatureGraphNet +from timm.models._helpers import load_state_dict +from timm.models._hub import has_hf_hub, download_cached_file, load_state_dict_from_hf +from timm.models._manipulate import adapt_input_conv +from timm.models._pretrained import PretrainedCfg +from timm.models._prune import adapt_model_from_file +from timm.models._registry import get_pretrained_cfg + +_logger = logging.getLogger(__name__) + +# Global variables for rarely used pretrained checkpoint download progress and hash check. +# Use set_pretrained_download_progress / set_pretrained_check_hash functions to toggle. +_DOWNLOAD_PROGRESS = False +_CHECK_HASH = False + + +def _resolve_pretrained_source(pretrained_cfg): + cfg_source = pretrained_cfg.get('source', '') + pretrained_url = pretrained_cfg.get('url', None) + pretrained_file = pretrained_cfg.get('file', None) + hf_hub_id = pretrained_cfg.get('hf_hub_id', None) + # resolve where to load pretrained weights from + load_from = '' + pretrained_loc = '' + if cfg_source == 'hf-hub' and has_hf_hub(necessary=True): + # hf-hub specified as source via model identifier + load_from = 'hf-hub' + assert hf_hub_id + pretrained_loc = hf_hub_id + else: + # default source == timm or unspecified + if pretrained_file: + load_from = 'file' + pretrained_loc = pretrained_file + elif pretrained_url: + load_from = 'url' + pretrained_loc = pretrained_url + elif hf_hub_id and has_hf_hub(necessary=True): + # hf-hub available as alternate weight source in default_cfg + load_from = 'hf-hub' + pretrained_loc = hf_hub_id + if load_from == 'hf-hub' and pretrained_cfg.get('hf_hub_filename', None): + # if a filename override is set, return tuple for location w/ (hub_id, filename) + pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename'] + return load_from, pretrained_loc + + +def set_pretrained_download_progress(enable=True): + """ Set download progress for pretrained weights on/off (globally). """ + global _DOWNLOAD_PROGRESS + _DOWNLOAD_PROGRESS = enable + + +def set_pretrained_check_hash(enable=True): + """ Set hash checking for pretrained weights on/off (globally). """ + global _CHECK_HASH + _CHECK_HASH = enable + + +def load_custom_pretrained( + model: nn.Module, + pretrained_cfg: Optional[Dict] = None, + load_fn: Optional[Callable] = None, +): + r"""Loads a custom (read non .pth) weight file + + Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls + a passed in custom load fun, or the `load_pretrained` model member fn. + + If the object is already present in `model_dir`, it's deserialized and returned. + The default value of `model_dir` is ``/checkpoints`` where + `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`. + + Args: + model: The instantiated model to load weights into + pretrained_cfg (dict): Default pretrained model cfg + load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named + 'laod_pretrained' on the model will be called if it exists + """ + pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None) + if not pretrained_cfg: + _logger.warning("Invalid pretrained config, cannot load weights.") + return + + load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg) + if not load_from: + _logger.warning("No pretrained weights exist for this model. Using random initialization.") + return + if load_from == 'hf-hub': # FIXME + _logger.warning("Hugging Face hub not currently supported for custom load pretrained models.") + elif load_from == 'url': + pretrained_loc = download_cached_file( + pretrained_loc, + check_hash=_CHECK_HASH, + progress=_DOWNLOAD_PROGRESS + ) + + if load_fn is not None: + load_fn(model, pretrained_loc) + elif hasattr(model, 'load_pretrained'): + model.load_pretrained(pretrained_loc) + else: + _logger.warning("Valid function to load pretrained weights is not available, using random initialization.") + + +def load_pretrained( + model: nn.Module, + pretrained_cfg: Optional[Dict] = None, + num_classes: int = 1000, + in_chans: int = 3, + filter_fn: Optional[Callable] = None, + strict: bool = True, +): + """ Load pretrained checkpoint + + Args: + model (nn.Module) : PyTorch model module + pretrained_cfg (Optional[Dict]): configuration for pretrained weights / target dataset + num_classes (int): num_classes for target model + in_chans (int): in_chans for target model + filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args) + strict (bool): strict load of checkpoint + + """ + pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None) + if not pretrained_cfg: + _logger.warning("Invalid pretrained config, cannot load weights.") + return + + load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg) + if load_from == 'file': + _logger.info(f'Loading pretrained weights from file ({pretrained_loc})') + state_dict = load_state_dict(pretrained_loc) + elif load_from == 'url': + _logger.info(f'Loading pretrained weights from url ({pretrained_loc})') + state_dict = load_state_dict_from_url( + pretrained_loc, + map_location='cpu', + progress=_DOWNLOAD_PROGRESS, + check_hash=_CHECK_HASH, + ) + elif load_from == 'hf-hub': + _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})') + if isinstance(pretrained_loc, (list, tuple)): + state_dict = load_state_dict_from_hf(*pretrained_loc) + else: + state_dict = load_state_dict_from_hf(pretrained_loc) + else: + _logger.warning("No pretrained weights exist or were found for this model. Using random initialization.") + return + + if filter_fn is not None: + # for backwards compat with filter fn that take one arg, try one first, the two + try: + state_dict = filter_fn(state_dict) + except TypeError: + state_dict = filter_fn(state_dict, model) + + input_convs = pretrained_cfg.get('first_conv', None) + if input_convs is not None and in_chans != 3: + if isinstance(input_convs, str): + input_convs = (input_convs,) + for input_conv_name in input_convs: + weight_name = input_conv_name + '.weight' + try: + state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name]) + _logger.info( + f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)') + except NotImplementedError as e: + del state_dict[weight_name] + strict = False + _logger.warning( + f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.') + + classifiers = pretrained_cfg.get('classifier', None) + label_offset = pretrained_cfg.get('label_offset', 0) + if classifiers is not None: + if isinstance(classifiers, str): + classifiers = (classifiers,) + if num_classes != pretrained_cfg['num_classes']: + for classifier_name in classifiers: + # completely discard fully connected if model num_classes doesn't match pretrained weights + state_dict.pop(classifier_name + '.weight', None) + state_dict.pop(classifier_name + '.bias', None) + strict = False + elif label_offset > 0: + for classifier_name in classifiers: + # special case for pretrained weights with an extra background class in pretrained weights + classifier_weight = state_dict[classifier_name + '.weight'] + state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:] + classifier_bias = state_dict[classifier_name + '.bias'] + state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:] + + model.load_state_dict(state_dict, strict=strict) + + +def pretrained_cfg_for_features(pretrained_cfg): + pretrained_cfg = deepcopy(pretrained_cfg) + # remove default pretrained cfg fields that don't have much relevance for feature backbone + to_remove = ('num_classes', 'classifier', 'global_pool') # add default final pool size? + for tr in to_remove: + pretrained_cfg.pop(tr, None) + return pretrained_cfg + + +def _filter_kwargs(kwargs, names): + if not kwargs or not names: + return + for n in names: + kwargs.pop(n, None) + + +def _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter): + """ Update the default_cfg and kwargs before passing to model + + Args: + pretrained_cfg: input pretrained cfg (updated in-place) + kwargs: keyword args passed to model build fn (updated in-place) + kwargs_filter: keyword arg keys that must be removed before model __init__ + """ + # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs) + default_kwarg_names = ('num_classes', 'global_pool', 'in_chans') + if pretrained_cfg.get('fixed_input_size', False): + # if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size + default_kwarg_names += ('img_size',) + + for n in default_kwarg_names: + # for legacy reasons, model __init__args uses img_size + in_chans as separate args while + # pretrained_cfg has one input_size=(C, H ,W) entry + if n == 'img_size': + input_size = pretrained_cfg.get('input_size', None) + if input_size is not None: + assert len(input_size) == 3 + kwargs.setdefault(n, input_size[-2:]) + elif n == 'in_chans': + input_size = pretrained_cfg.get('input_size', None) + if input_size is not None: + assert len(input_size) == 3 + kwargs.setdefault(n, input_size[0]) + else: + default_val = pretrained_cfg.get(n, None) + if default_val is not None: + kwargs.setdefault(n, pretrained_cfg[n]) + + # Filter keyword args for task specific model variants (some 'features only' models, etc.) + _filter_kwargs(kwargs, names=kwargs_filter) + + +def resolve_pretrained_cfg( + variant: str, + pretrained_cfg=None, + pretrained_cfg_overlay=None, +) -> PretrainedCfg: + model_with_tag = variant + pretrained_tag = None + if pretrained_cfg: + if isinstance(pretrained_cfg, dict): + # pretrained_cfg dict passed as arg, validate by converting to PretrainedCfg + pretrained_cfg = PretrainedCfg(**pretrained_cfg) + elif isinstance(pretrained_cfg, str): + pretrained_tag = pretrained_cfg + pretrained_cfg = None + + # fallback to looking up pretrained cfg in model registry by variant identifier + if not pretrained_cfg: + if pretrained_tag: + model_with_tag = '.'.join([variant, pretrained_tag]) + pretrained_cfg = get_pretrained_cfg(model_with_tag) + + if not pretrained_cfg: + _logger.warning( + f"No pretrained configuration specified for {model_with_tag} model. Using a default." + f" Please add a config to the model pretrained_cfg registry or pass explicitly.") + pretrained_cfg = PretrainedCfg() # instance with defaults + + pretrained_cfg_overlay = pretrained_cfg_overlay or {} + if not pretrained_cfg.architecture: + pretrained_cfg_overlay.setdefault('architecture', variant) + pretrained_cfg = dataclasses.replace(pretrained_cfg, **pretrained_cfg_overlay) + + return pretrained_cfg + + +def build_model_with_cfg( + model_cls: Callable, + variant: str, + pretrained: bool, + pretrained_cfg: Optional[Dict] = None, + pretrained_cfg_overlay: Optional[Dict] = None, + model_cfg: Optional[Any] = None, + feature_cfg: Optional[Dict] = None, + pretrained_strict: bool = True, + pretrained_filter_fn: Optional[Callable] = None, + kwargs_filter: Optional[Tuple[str]] = None, + **kwargs, +): + """ Build model with specified default_cfg and optional model_cfg + + This helper fn aids in the construction of a model including: + * handling default_cfg and associated pretrained weight loading + * passing through optional model_cfg for models with config based arch spec + * features_only model adaptation + * pruning config / model adaptation + + Args: + model_cls (nn.Module): model class + variant (str): model variant name + pretrained (bool): load pretrained weights + pretrained_cfg (dict): model's pretrained weight/task config + model_cfg (Optional[Dict]): model's architecture config + feature_cfg (Optional[Dict]: feature extraction adapter config + pretrained_strict (bool): load pretrained weights strictly + pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights + kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model + **kwargs: model args passed through to model __init__ + """ + pruned = kwargs.pop('pruned', False) + features = False + feature_cfg = feature_cfg or {} + + # resolve and update model pretrained config and model kwargs + pretrained_cfg = resolve_pretrained_cfg( + variant, + pretrained_cfg=pretrained_cfg, + pretrained_cfg_overlay=pretrained_cfg_overlay + ) + + # FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model + pretrained_cfg = pretrained_cfg.to_dict() + + _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter) + + # Setup for feature extraction wrapper done at end of this fn + if kwargs.pop('features_only', False): + features = True + feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4)) + if 'out_indices' in kwargs: + feature_cfg['out_indices'] = kwargs.pop('out_indices') + + # Instantiate the model + if model_cfg is None: + model = model_cls(**kwargs) + else: + model = model_cls(cfg=model_cfg, **kwargs) + model.pretrained_cfg = pretrained_cfg + model.default_cfg = model.pretrained_cfg # alias for backwards compat + + if pruned: + model = adapt_model_from_file(model, variant) + + # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats + num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000)) + if pretrained: + if pretrained_cfg.get('custom_load', False): + load_custom_pretrained( + model, + pretrained_cfg=pretrained_cfg, + ) + else: + load_pretrained( + model, + pretrained_cfg=pretrained_cfg, + num_classes=num_classes_pretrained, + in_chans=kwargs.get('in_chans', 3), + filter_fn=pretrained_filter_fn, + strict=pretrained_strict, + ) + + # Wrap the model in a feature extraction module if enabled + if features: + feature_cls = FeatureListNet + if 'feature_cls' in feature_cfg: + feature_cls = feature_cfg.pop('feature_cls') + if isinstance(feature_cls, str): + feature_cls = feature_cls.lower() + if 'hook' in feature_cls: + feature_cls = FeatureHookNet + elif feature_cls == 'fx': + feature_cls = FeatureGraphNet + else: + assert False, f'Unknown feature class {feature_cls}' + model = feature_cls(model, **feature_cfg) + model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg) # add back default_cfg + model.default_cfg = model.pretrained_cfg # alias for backwards compat + + return model diff --git a/timm/models/efficientnet_blocks.py b/timm/models/_efficientnet_blocks.py similarity index 99% rename from timm/models/efficientnet_blocks.py rename to timm/models/_efficientnet_blocks.py index 34a31757..92b849e4 100644 --- a/timm/models/efficientnet_blocks.py +++ b/timm/models/_efficientnet_blocks.py @@ -2,13 +2,12 @@ Hacked together by / Copyright 2019, Ross Wightman """ -import math import torch import torch.nn as nn from torch.nn import functional as F -from .layers import create_conv2d, DropPath, make_divisible, create_act_layer, get_norm_act_layer +from timm.layers import create_conv2d, DropPath, make_divisible, create_act_layer, get_norm_act_layer __all__ = [ 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual'] diff --git a/timm/models/efficientnet_builder.py b/timm/models/_efficientnet_builder.py similarity index 99% rename from timm/models/efficientnet_builder.py rename to timm/models/_efficientnet_builder.py index 67d15a86..e6cd05ae 100644 --- a/timm/models/efficientnet_builder.py +++ b/timm/models/_efficientnet_builder.py @@ -14,8 +14,8 @@ from functools import partial import torch.nn as nn -from .efficientnet_blocks import * -from .layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible +from ._efficientnet_blocks import * +from timm.layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible __all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights", 'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT'] diff --git a/timm/models/factory.py b/timm/models/_factory.py similarity index 94% rename from timm/models/factory.py rename to timm/models/_factory.py index 9e06c1aa..2b050ad6 100644 --- a/timm/models/factory.py +++ b/timm/models/_factory.py @@ -2,11 +2,11 @@ import os from typing import Any, Dict, Optional, Union from urllib.parse import urlsplit -from .pretrained import PretrainedCfg, split_model_name_tag -from .helpers import load_checkpoint -from .hub import load_model_config_from_hf -from .layers import set_layer_config -from .registry import is_model, model_entrypoint +from timm.layers import set_layer_config +from ._pretrained import PretrainedCfg, split_model_name_tag +from ._helpers import load_checkpoint +from ._hub import load_model_config_from_hf +from ._registry import is_model, model_entrypoint def parse_model_name(model_name): diff --git a/timm/models/features.py b/timm/models/_features.py similarity index 100% rename from timm/models/features.py rename to timm/models/_features.py diff --git a/timm/models/fx_features.py b/timm/models/_features_fx.py similarity index 93% rename from timm/models/fx_features.py rename to timm/models/_features_fx.py index b09381b7..2d4a33c2 100644 --- a/timm/models/fx_features.py +++ b/timm/models/_features_fx.py @@ -6,7 +6,7 @@ from typing import Callable, List, Dict, Union, Type import torch from torch import nn -from .features import _get_feature_info +from ._features import _get_feature_info try: from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor @@ -15,9 +15,9 @@ except ImportError: has_fx_feature_extraction = False # Layers we went to treat as leaf modules -from .layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame -from .layers.non_local_attn import BilinearAttnTransform -from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame +from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame +from timm.layers.non_local_attn import BilinearAttnTransform +from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame # NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here # BUT modules from timm.models should use the registration mechanism below @@ -29,7 +29,7 @@ _leaf_modules = { } try: - from .layers import InplaceAbn + from timm.layers import InplaceAbn _leaf_modules.add(InplaceAbn) except ImportError: pass diff --git a/timm/models/_helpers.py b/timm/models/_helpers.py new file mode 100644 index 00000000..2856842d --- /dev/null +++ b/timm/models/_helpers.py @@ -0,0 +1,113 @@ +""" Model creation / weight loading / state_dict helpers + +Hacked together by / Copyright 2020 Ross Wightman +""" +import logging +import os +from collections import OrderedDict + +import torch + +import timm.models._builder + +_logger = logging.getLogger(__name__) + + +def clean_state_dict(state_dict): + # 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training + cleaned_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k[7:] if k.startswith('module.') else k + cleaned_state_dict[name] = v + return cleaned_state_dict + + +def load_state_dict(checkpoint_path, use_ema=True): + if checkpoint_path and os.path.isfile(checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location='cpu') + state_dict_key = '' + if isinstance(checkpoint, dict): + if use_ema and checkpoint.get('state_dict_ema', None) is not None: + state_dict_key = 'state_dict_ema' + elif use_ema and checkpoint.get('model_ema', None) is not None: + state_dict_key = 'model_ema' + elif 'state_dict' in checkpoint: + state_dict_key = 'state_dict' + elif 'model' in checkpoint: + state_dict_key = 'model' + state_dict = clean_state_dict(checkpoint[state_dict_key] if state_dict_key else checkpoint) + _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path)) + return state_dict + else: + _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) + raise FileNotFoundError() + + +def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True, remap=False): + if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'): + # numpy checkpoint, try to load via model specific load_pretrained fn + if hasattr(model, 'load_pretrained'): + timm.models._model_builder.load_pretrained(checkpoint_path) + else: + raise NotImplementedError('Model cannot load numpy checkpoint') + return + state_dict = load_state_dict(checkpoint_path, use_ema) + if remap: + state_dict = remap_checkpoint(model, state_dict) + incompatible_keys = model.load_state_dict(state_dict, strict=strict) + return incompatible_keys + + +def remap_checkpoint(model, state_dict, allow_reshape=True): + """ remap checkpoint by iterating over state dicts in order (ignoring original keys). + This assumes models (and originating state dict) were created with params registered in same order. + """ + out_dict = {} + for (ka, va), (kb, vb) in zip(model.state_dict().items(), state_dict.items()): + assert va.numel == vb.numel, f'Tensor size mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.' + if va.shape != vb.shape: + if allow_reshape: + vb = vb.reshape(va.shape) + else: + assert False, f'Tensor shape mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.' + out_dict[ka] = vb + return out_dict + + +def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True): + resume_epoch = None + if os.path.isfile(checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location='cpu') + if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + if log_info: + _logger.info('Restoring model state from checkpoint...') + state_dict = clean_state_dict(checkpoint['state_dict']) + model.load_state_dict(state_dict) + + if optimizer is not None and 'optimizer' in checkpoint: + if log_info: + _logger.info('Restoring optimizer state from checkpoint...') + optimizer.load_state_dict(checkpoint['optimizer']) + + if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint: + if log_info: + _logger.info('Restoring AMP loss scaler state from checkpoint...') + loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key]) + + if 'epoch' in checkpoint: + resume_epoch = checkpoint['epoch'] + if 'version' in checkpoint and checkpoint['version'] > 1: + resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save + + if log_info: + _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) + else: + model.load_state_dict(checkpoint) + if log_info: + _logger.info("Loaded checkpoint '{}'".format(checkpoint_path)) + return resume_epoch + else: + _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) + raise FileNotFoundError() + + diff --git a/timm/models/hub.py b/timm/models/_hub.py similarity index 99% rename from timm/models/hub.py rename to timm/models/_hub.py index 18c5444a..2a87ae7e 100644 --- a/timm/models/hub.py +++ b/timm/models/_hub.py @@ -15,7 +15,7 @@ except ImportError: from torch.hub import _get_torch_home as get_dir from timm import __version__ -from timm.models.pretrained import filter_pretrained_cfg +from timm.models._pretrained import filter_pretrained_cfg try: from huggingface_hub import ( diff --git a/timm/models/_manipulate.py b/timm/models/_manipulate.py new file mode 100644 index 00000000..82a922a2 --- /dev/null +++ b/timm/models/_manipulate.py @@ -0,0 +1,255 @@ +import collections.abc +import math +import re +from collections import defaultdict +from itertools import chain +from typing import Callable, Union, Dict + +import torch +from torch import nn as nn +from torch.utils.checkpoint import checkpoint + + +def model_parameters(model, exclude_head=False): + if exclude_head: + # FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering + return [p for p in model.parameters()][:-2] + else: + return model.parameters() + + +def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = '.'.join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +def named_modules(module: nn.Module, name='', depth_first=True, include_root=False): + if not depth_first and include_root: + yield name, module + for child_name, child_module in module.named_children(): + child_name = '.'.join((name, child_name)) if name else child_name + yield from named_modules( + module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + yield name, module + + +def named_modules_with_params(module: nn.Module, name='', depth_first=True, include_root=False): + if module._parameters and not depth_first and include_root: + yield name, module + for child_name, child_module in module.named_children(): + child_name = '.'.join((name, child_name)) if name else child_name + yield from named_modules_with_params( + module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if module._parameters and depth_first and include_root: + yield name, module + + +MATCH_PREV_GROUP = (99999,) + + +def group_with_matcher( + named_objects, + group_matcher: Union[Dict, Callable], + output_values: bool = False, + reverse: bool = False +): + if isinstance(group_matcher, dict): + # dictionary matcher contains a dict of raw-string regex expr that must be compiled + compiled = [] + for group_ordinal, (group_name, mspec) in enumerate(group_matcher.items()): + if mspec is None: + continue + # map all matching specifications into 3-tuple (compiled re, prefix, suffix) + if isinstance(mspec, (tuple, list)): + # multi-entry match specifications require each sub-spec to be a 2-tuple (re, suffix) + for sspec in mspec: + compiled += [(re.compile(sspec[0]), (group_ordinal,), sspec[1])] + else: + compiled += [(re.compile(mspec), (group_ordinal,), None)] + group_matcher = compiled + + def _get_grouping(name): + if isinstance(group_matcher, (list, tuple)): + for match_fn, prefix, suffix in group_matcher: + r = match_fn.match(name) + if r: + parts = (prefix, r.groups(), suffix) + # map all tuple elem to int for numeric sort, filter out None entries + return tuple(map(float, chain.from_iterable(filter(None, parts)))) + return float('inf'), # un-matched layers (neck, head) mapped to largest ordinal + else: + ord = group_matcher(name) + if not isinstance(ord, collections.abc.Iterable): + return ord, + return tuple(ord) + + # map layers into groups via ordinals (ints or tuples of ints) from matcher + grouping = defaultdict(list) + for k, v in named_objects: + grouping[_get_grouping(k)].append(v if output_values else k) + + # remap to integers + layer_id_to_param = defaultdict(list) + lid = -1 + for k in sorted(filter(lambda x: x is not None, grouping.keys())): + if lid < 0 or k[-1] != MATCH_PREV_GROUP[0]: + lid += 1 + layer_id_to_param[lid].extend(grouping[k]) + + if reverse: + assert not output_values, "reverse mapping only sensible for name output" + # output reverse mapping + param_to_layer_id = {} + for lid, lm in layer_id_to_param.items(): + for n in lm: + param_to_layer_id[n] = lid + return param_to_layer_id + + return layer_id_to_param + + +def group_parameters( + module: nn.Module, + group_matcher, + output_values=False, + reverse=False, +): + return group_with_matcher( + module.named_parameters(), group_matcher, output_values=output_values, reverse=reverse) + + +def group_modules( + module: nn.Module, + group_matcher, + output_values=False, + reverse=False, +): + return group_with_matcher( + named_modules_with_params(module), group_matcher, output_values=output_values, reverse=reverse) + + +def flatten_modules(named_modules, depth=1, prefix='', module_types='sequential'): + prefix_is_tuple = isinstance(prefix, tuple) + if isinstance(module_types, str): + if module_types == 'container': + module_types = (nn.Sequential, nn.ModuleList, nn.ModuleDict) + else: + module_types = (nn.Sequential,) + for name, module in named_modules: + if depth and isinstance(module, module_types): + yield from flatten_modules( + module.named_children(), + depth - 1, + prefix=(name,) if prefix_is_tuple else name, + module_types=module_types, + ) + else: + if prefix_is_tuple: + name = prefix + (name,) + yield name, module + else: + if prefix: + name = '.'.join([prefix, name]) + yield name, module + + +def checkpoint_seq( + functions, + x, + every=1, + flatten=False, + skip_last=False, + preserve_rng_state=True +): + r"""A helper function for checkpointing sequential models. + + Sequential models execute a list of modules/functions in order + (sequentially). Therefore, we can divide such a sequence into segments + and checkpoint each segment. All segments except run in :func:`torch.no_grad` + manner, i.e., not storing the intermediate activations. The inputs of each + checkpointed segment will be saved for re-running the segment in the backward pass. + + See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works. + + .. warning:: + Checkpointing currently only supports :func:`torch.autograd.backward` + and only if its `inputs` argument is not passed. :func:`torch.autograd.grad` + is not supported. + + .. warning: + At least one of the inputs needs to have :code:`requires_grad=True` if + grads are needed for model inputs, otherwise the checkpointed part of the + model won't have gradients. + + Args: + functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially. + x: A Tensor that is input to :attr:`functions` + every: checkpoint every-n functions (default: 1) + flatten (bool): flatten nn.Sequential of nn.Sequentials + skip_last (bool): skip checkpointing the last function in the sequence if True + preserve_rng_state (bool, optional, default=True): Omit stashing and restoring + the RNG state during each checkpoint. + + Returns: + Output of running :attr:`functions` sequentially on :attr:`*inputs` + + Example: + >>> model = nn.Sequential(...) + >>> input_var = checkpoint_seq(model, input_var, every=2) + """ + def run_function(start, end, functions): + def forward(_x): + for j in range(start, end + 1): + _x = functions[j](_x) + return _x + return forward + + if isinstance(functions, torch.nn.Sequential): + functions = functions.children() + if flatten: + functions = chain.from_iterable(functions) + if not isinstance(functions, (tuple, list)): + functions = tuple(functions) + + num_checkpointed = len(functions) + if skip_last: + num_checkpointed -= 1 + end = -1 + for start in range(0, num_checkpointed, every): + end = min(start + every - 1, num_checkpointed - 1) + x = checkpoint(run_function(start, end, functions), x, preserve_rng_state=preserve_rng_state) + if skip_last: + return run_function(end + 1, len(functions) - 1, functions)(x) + return x + + +def adapt_input_conv(in_chans, conv_weight): + conv_type = conv_weight.dtype + conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU + O, I, J, K = conv_weight.shape + if in_chans == 1: + if I > 3: + assert conv_weight.shape[1] % 3 == 0 + # For models with space2depth stems + conv_weight = conv_weight.reshape(O, I // 3, 3, J, K) + conv_weight = conv_weight.sum(dim=2, keepdim=False) + else: + conv_weight = conv_weight.sum(dim=1, keepdim=True) + elif in_chans != 3: + if I != 3: + raise NotImplementedError('Weight format not supported by conversion.') + else: + # NOTE this strategy should be better than random init, but there could be other combinations of + # the original RGB input layer weights that'd work better for specific cases. + repeat = int(math.ceil(in_chans / 3)) + conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] + conv_weight *= (3 / float(in_chans)) + conv_weight = conv_weight.to(conv_type) + return conv_weight diff --git a/timm/models/pretrained.py b/timm/models/_pretrained.py similarity index 100% rename from timm/models/pretrained.py rename to timm/models/_pretrained.py diff --git a/timm/models/_prune.py b/timm/models/_prune.py new file mode 100644 index 00000000..0d744e40 --- /dev/null +++ b/timm/models/_prune.py @@ -0,0 +1,111 @@ +import os +from copy import deepcopy + +from torch import nn as nn + +from timm.layers import Conv2dSame, BatchNormAct2d, Linear + + +def extract_layer(model, layer): + layer = layer.split('.') + module = model + if hasattr(model, 'module') and layer[0] != 'module': + module = model.module + if not hasattr(model, 'module') and layer[0] == 'module': + layer = layer[1:] + for l in layer: + if hasattr(module, l): + if not l.isdigit(): + module = getattr(module, l) + else: + module = module[int(l)] + else: + return module + return module + + +def set_layer(model, layer, val): + layer = layer.split('.') + module = model + if hasattr(model, 'module') and layer[0] != 'module': + module = model.module + lst_index = 0 + module2 = module + for l in layer: + if hasattr(module2, l): + if not l.isdigit(): + module2 = getattr(module2, l) + else: + module2 = module2[int(l)] + lst_index += 1 + lst_index -= 1 + for l in layer[:lst_index]: + if not l.isdigit(): + module = getattr(module, l) + else: + module = module[int(l)] + l = layer[lst_index] + setattr(module, l, val) + + +def adapt_model_from_string(parent_module, model_string): + separator = '***' + state_dict = {} + lst_shape = model_string.split(separator) + for k in lst_shape: + k = k.split(':') + key = k[0] + shape = k[1][1:-1].split(',') + if shape[0] != '': + state_dict[key] = [int(i) for i in shape] + + new_module = deepcopy(parent_module) + for n, m in parent_module.named_modules(): + old_module = extract_layer(parent_module, n) + if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame): + if isinstance(old_module, Conv2dSame): + conv = Conv2dSame + else: + conv = nn.Conv2d + s = state_dict[n + '.weight'] + in_channels = s[1] + out_channels = s[0] + g = 1 + if old_module.groups > 1: + in_channels = out_channels + g = in_channels + new_conv = conv( + in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size, + bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation, + groups=g, stride=old_module.stride) + set_layer(new_module, n, new_conv) + elif isinstance(old_module, BatchNormAct2d): + new_bn = BatchNormAct2d( + state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum, + affine=old_module.affine, track_running_stats=True) + new_bn.drop = old_module.drop + new_bn.act = old_module.act + set_layer(new_module, n, new_bn) + elif isinstance(old_module, nn.BatchNorm2d): + new_bn = nn.BatchNorm2d( + num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum, + affine=old_module.affine, track_running_stats=True) + set_layer(new_module, n, new_bn) + elif isinstance(old_module, nn.Linear): + # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer? + num_features = state_dict[n + '.weight'][1] + new_fc = Linear( + in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None) + set_layer(new_module, n, new_fc) + if hasattr(new_module, 'num_features'): + new_module.num_features = num_features + new_module.eval() + parent_module.eval() + + return new_module + + +def adapt_model_from_file(parent_module, model_variant): + adapt_file = os.path.join(os.path.dirname(__file__), '_pruned', model_variant + '.txt') + with open(adapt_file, 'r') as f: + return adapt_model_from_string(parent_module, f.read().strip()) diff --git a/timm/models/pruned/ecaresnet101d_pruned.txt b/timm/models/_pruned/ecaresnet101d_pruned.txt similarity index 100% rename from timm/models/pruned/ecaresnet101d_pruned.txt rename to timm/models/_pruned/ecaresnet101d_pruned.txt diff --git a/timm/models/pruned/ecaresnet50d_pruned.txt b/timm/models/_pruned/ecaresnet50d_pruned.txt similarity index 100% rename from timm/models/pruned/ecaresnet50d_pruned.txt rename to timm/models/_pruned/ecaresnet50d_pruned.txt diff --git a/timm/models/pruned/efficientnet_b1_pruned.txt b/timm/models/_pruned/efficientnet_b1_pruned.txt similarity index 100% rename from timm/models/pruned/efficientnet_b1_pruned.txt rename to timm/models/_pruned/efficientnet_b1_pruned.txt diff --git a/timm/models/pruned/efficientnet_b2_pruned.txt b/timm/models/_pruned/efficientnet_b2_pruned.txt similarity index 100% rename from timm/models/pruned/efficientnet_b2_pruned.txt rename to timm/models/_pruned/efficientnet_b2_pruned.txt diff --git a/timm/models/pruned/efficientnet_b3_pruned.txt b/timm/models/_pruned/efficientnet_b3_pruned.txt similarity index 100% rename from timm/models/pruned/efficientnet_b3_pruned.txt rename to timm/models/_pruned/efficientnet_b3_pruned.txt diff --git a/timm/models/registry.py b/timm/models/_registry.py similarity index 95% rename from timm/models/registry.py rename to timm/models/_registry.py index 159ffb5f..97c8fd59 100644 --- a/timm/models/registry.py +++ b/timm/models/_registry.py @@ -9,7 +9,7 @@ from collections import defaultdict, deque from copy import deepcopy from typing import List, Optional, Union, Tuple -from .pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag +from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag __all__ = [ 'list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules', @@ -167,10 +167,12 @@ def is_model(model_name): return arch_name in _model_entrypoints -def model_entrypoint(model_name): +def model_entrypoint(model_name, module_filter: Optional[str] = None): """Fetch a model entrypoint for specified model name """ arch_name = get_arch_name(model_name) + if module_filter and arch_name not in _module_to_models.get(module_filter, {}): + raise RuntimeError(f'Model ({model_name} not found in module {module_filter}.') return _model_entrypoints[arch_name] diff --git a/timm/models/beit.py b/timm/models/beit.py index 1f6bf82b..7c4dd14d 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -46,12 +46,13 @@ import torch.nn as nn import torch.nn.functional as F from torch.utils.checkpoint import checkpoint -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_ -from .registry import register_model +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD +from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_ +from ._builder import build_model_with_cfg +from ._registry import register_model from .vision_transformer import checkpoint_filter_fn +__all__ = ['Beit'] def _cfg(url='', **kwargs): return { diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index 3815fa30..c67144cc 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -13,9 +13,9 @@ Consider all of the models definitions here as experimental WIP and likely to ch Hacked together by / copyright Ross Wightman, 2021. """ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from ._builder import build_model_with_cfg +from ._registry import register_model from .byobnet import ByoBlockCfg, ByoModelCfg, ByobNet, interleave_blocks -from .helpers import build_model_with_cfg -from .registry import register_model __all__ = [] diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index 1e402629..0e5c9c7f 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -26,18 +26,18 @@ Hacked together by / copyright Ross Wightman, 2021. """ import math from dataclasses import dataclass, field, replace -from typing import Tuple, List, Dict, Optional, Union, Any, Callable, Sequence from functools import partial +from typing import Tuple, List, Dict, Optional, Union, Any, Callable, Sequence import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, named_apply, checkpoint_seq -from .layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \ - create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0, EvoNorm2dS0a,\ - EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a, FilterResponseNormAct2d, FilterResponseNormTlu2d -from .registry import register_model +from timm.layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \ + create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0a +from ._builder import build_model_with_cfg +from ._manipulate import named_apply, checkpoint_seq +from ._registry import register_model __all__ = ['ByobNet', 'ByoModelCfg', 'ByoBlockCfg', 'create_byob_stem', 'create_block'] diff --git a/timm/models/cait.py b/timm/models/cait.py index c0892099..15dcd956 100644 --- a/timm/models/cait.py +++ b/timm/models/cait.py @@ -8,17 +8,16 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W """ # Copyright (c) 2015-present, Facebook, Inc. # All rights reserved. -from copy import deepcopy from functools import partial import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_ -from .registry import register_model - +from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_ +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model __all__ = ['Cait', 'ClassAttn', 'LayerScaleBlockClassAttn', 'LayerScaleBlock', 'TalkingHeadAttn'] diff --git a/timm/models/coat.py b/timm/models/coat.py index c3071a6c..4ed6d8e8 100644 --- a/timm/models/coat.py +++ b/timm/models/coat.py @@ -7,7 +7,6 @@ Official CoaT code at: https://github.com/mlpc-ucsd/CoaT Modified from timm/models/vision_transformer.py """ -from copy import deepcopy from functools import partial from typing import Tuple, List, Union @@ -16,19 +15,11 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_ -from .registry import register_model -from .layers import _assert - - -__all__ = [ - "coat_tiny", - "coat_mini", - "coat_lite_tiny", - "coat_lite_mini", - "coat_lite_small" -] +from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert +from ._builder import build_model_with_cfg +from ._registry import register_model + +__all__ = ['CoaT'] def _cfg_coat(url='', **kwargs): diff --git a/timm/models/convit.py b/timm/models/convit.py index 26849f6e..d117ccdc 100644 --- a/timm/models/convit.py +++ b/timm/models/convit.py @@ -22,20 +22,20 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py ''' +from functools import partial + import torch import torch.nn as nn -from functools import partial -import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import DropPath, to_2tuple, trunc_normal_, PatchEmbed, Mlp -from .registry import register_model +from timm.layers import DropPath, trunc_normal_, PatchEmbed, Mlp +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_module +from ._registry import register_model from .vision_transformer_hybrid import HybridEmbed -from .fx_features import register_notrace_module -import torch -import torch.nn as nn + +__all__ = ['ConViT'] def _cfg(url='', **kwargs): diff --git a/timm/models/convmixer.py b/timm/models/convmixer.py index e7e2481a..3a8c6cf5 100644 --- a/timm/models/convmixer.py +++ b/timm/models/convmixer.py @@ -5,9 +5,12 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.models.registry import register_model -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import SelectAdaptivePool2d +from timm.layers import SelectAdaptivePool2d +from ._registry import register_model +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq + +__all__ = ['ConvMixer'] def _cfg(url='', **kwargs): diff --git a/timm/models/convnext.py b/timm/models/convnext.py index 36a484b3..eea5782a 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -18,12 +18,12 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import named_apply, build_model_with_cfg, checkpoint_seq -from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, LayerNorm, \ +from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, LayerNorm, \ create_conv2d, get_act_layer, make_divisible, to_ntuple -from .pretrained import generate_default_cfgs -from .registry import register_model - +from ._builder import build_model_with_cfg +from ._manipulate import named_apply, checkpoint_seq +from ._pretrained import generate_default_cfgs +from ._registry import register_model __all__ = ['ConvNeXt'] # model_registry will add each entrypoint fn to this diff --git a/timm/models/crossvit.py b/timm/models/crossvit.py index 764eb3fe..908fcf6d 100644 --- a/timm/models/crossvit.py +++ b/timm/models/crossvit.py @@ -24,21 +24,22 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W Modifed from Timm. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py """ +from functools import partial +from typing import List from typing import Tuple import torch -import torch.nn as nn -import torch.nn.functional as F import torch.hub -from functools import partial -from typing import List +import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_function -from .helpers import build_model_with_cfg -from .layers import DropPath, to_2tuple, trunc_normal_, _assert -from .registry import register_model -from .vision_transformer import Mlp, Block +from timm.layers import DropPath, to_2tuple, trunc_normal_, _assert +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._registry import register_model +from .vision_transformer import Block + +__all__ = ['CrossViT'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py index 2c09e7e3..280f929e 100644 --- a/timm/models/cspnet.py +++ b/timm/models/cspnet.py @@ -12,20 +12,18 @@ Reference impl via darknet cfg files at https://github.com/WongKinYiu/CrossStage Hacked together by / Copyright 2020 Ross Wightman """ -import collections.abc -from dataclasses import dataclass, field, asdict +from dataclasses import dataclass, asdict from functools import partial -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import torch import torch.nn as nn -import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, named_apply, MATCH_PREV_GROUP -from .layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, get_attn, create_act_layer, make_divisible -from .registry import register_model - +from timm.layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, get_attn, create_act_layer, make_divisible +from ._builder import build_model_with_cfg +from ._manipulate import named_apply, MATCH_PREV_GROUP +from ._registry import register_model __all__ = ['CspNet'] # model_registry will add each entrypoint fn to this diff --git a/timm/models/deit.py b/timm/models/deit.py index 3205b024..24fbbe56 100644 --- a/timm/models/deit.py +++ b/timm/models/deit.py @@ -17,9 +17,11 @@ from torch import nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.vision_transformer import VisionTransformer, trunc_normal_, checkpoint_filter_fn +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model -from .helpers import build_model_with_cfg, checkpoint_seq -from .registry import register_model +__all__ = ['VisionTransformerDistilled'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/densenet.py b/timm/models/densenet.py index 1afdfd7b..e731f7b0 100644 --- a/timm/models/densenet.py +++ b/timm/models/densenet.py @@ -4,7 +4,6 @@ fixed kwargs passthrough and addition of dynamic global avg/max pool. """ import re from collections import OrderedDict -from functools import partial import torch import torch.nn as nn @@ -13,9 +12,10 @@ import torch.utils.checkpoint as cp from torch.jit.annotations import List from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, MATCH_PREV_GROUP -from .layers import BatchNormAct2d, create_norm_act_layer, BlurPool2d, create_classifier -from .registry import register_model +from timm.layers import BatchNormAct2d, create_norm_act_layer, BlurPool2d, create_classifier +from ._builder import build_model_with_cfg +from ._manipulate import MATCH_PREV_GROUP +from ._registry import register_model __all__ = ['DenseNet'] diff --git a/timm/models/dla.py b/timm/models/dla.py index 0ab807c0..204fcb4b 100644 --- a/timm/models/dla.py +++ b/timm/models/dla.py @@ -13,9 +13,9 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import create_classifier -from .registry import register_model +from timm.layers import create_classifier +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['DLA'] diff --git a/timm/models/dpn.py b/timm/models/dpn.py index 95159729..87bd918f 100644 --- a/timm/models/dpn.py +++ b/timm/models/dpn.py @@ -15,9 +15,9 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import BatchNormAct2d, ConvNormAct, create_conv2d, create_classifier -from .registry import register_model +from timm.layers import BatchNormAct2d, ConvNormAct, create_conv2d, create_classifier +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['DPN'] diff --git a/timm/models/edgenext.py b/timm/models/edgenext.py index 422d4f2c..d90471fb 100644 --- a/timm/models/edgenext.py +++ b/timm/models/edgenext.py @@ -8,20 +8,20 @@ Original code and weights from https://github.com/mmaaz60/EdgeNeXt Modifications and additions for timm by / Copyright 2022, Ross Wightman """ import math -import torch from collections import OrderedDict from functools import partial from typing import Tuple -from torch import nn +import torch import torch.nn.functional as F +from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_module -from .layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d -from .helpers import named_apply, build_model_with_cfg, checkpoint_seq -from .registry import register_model - +from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_module +from ._manipulate import named_apply, checkpoint_seq +from ._registry import register_model __all__ = ['EdgeNeXt'] # model_registry will add each entrypoint fn to this diff --git a/timm/models/efficientformer.py b/timm/models/efficientformer.py index 4749d93a..4f33f29a 100644 --- a/timm/models/efficientformer.py +++ b/timm/models/efficientformer.py @@ -18,9 +18,11 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import DropPath, trunc_normal_, to_2tuple, Mlp -from .registry import register_model +from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp +from ._builder import build_model_with_cfg +from ._registry import register_model + +__all__ = ['EfficientFormer'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 3c0efc96..a1324ae3 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -42,15 +42,15 @@ import torch import torch.nn as nn import torch.nn.functional as F - from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .efficientnet_blocks import SqueezeExcite -from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights,\ +from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, GroupNormAct +from ._builder import build_model_with_cfg, pretrained_cfg_for_features +from ._efficientnet_blocks import SqueezeExcite +from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \ round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT -from .features import FeatureInfo, FeatureHooks -from .helpers import build_model_with_cfg, pretrained_cfg_for_features, checkpoint_seq -from .layers import create_conv2d, create_classifier, get_norm_act_layer, EvoNorm2dS0, GroupNormAct -from .registry import register_model +from ._features import FeatureInfo, FeatureHooks +from ._manipulate import checkpoint_seq +from ._registry import register_model __all__ = ['EfficientNet', 'EfficientNetFeatures'] diff --git a/timm/models/gcvit.py b/timm/models/gcvit.py index fb375e2c..ec9b7e5e 100644 --- a/timm/models/gcvit.py +++ b/timm/models/gcvit.py @@ -28,12 +28,13 @@ import torch.nn as nn import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_function -from .helpers import build_model_with_cfg, named_apply -from .layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d,\ +from timm.layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d, \ get_attn, get_act_layer, get_norm_layer, _assert -from .registry import register_model -from .vision_transformer_relpos import RelPosMlp, RelPosBias # FIXME move to common location +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._manipulate import named_apply +from ._registry import register_model +from .vision_transformer_relpos import RelPosBias # FIXME move to common location __all__ = ['GlobalContextVit'] diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py index e19af88b..492049b9 100644 --- a/timm/models/ghostnet.py +++ b/timm/models/ghostnet.py @@ -11,13 +11,12 @@ import torch import torch.nn as nn import torch.nn.functional as F - from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .layers import SelectAdaptivePool2d, Linear, make_divisible -from .efficientnet_blocks import SqueezeExcite, ConvBnAct -from .helpers import build_model_with_cfg, checkpoint_seq -from .registry import register_model - +from timm.layers import SelectAdaptivePool2d, Linear, make_divisible +from ._builder import build_model_with_cfg +from ._efficientnet_blocks import SqueezeExcite, ConvBnAct +from ._manipulate import checkpoint_seq +from ._registry import register_model __all__ = ['GhostNet'] diff --git a/timm/models/gluon_resnet.py b/timm/models/gluon_resnet.py index a1e73554..2b4131fb 100644 --- a/timm/models/gluon_resnet.py +++ b/timm/models/gluon_resnet.py @@ -5,11 +5,13 @@ by Ross Wightman """ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import SEModule -from .registry import register_model +from timm.layers import SEModule +from ._builder import build_model_with_cfg +from ._registry import register_model from .resnet import ResNet, Bottleneck, BasicBlock +__all__ = [] + def _cfg(url='', **kwargs): return { diff --git a/timm/models/gluon_xception.py b/timm/models/gluon_xception.py index a9c946b2..b487d0fd 100644 --- a/timm/models/gluon_xception.py +++ b/timm/models/gluon_xception.py @@ -13,9 +13,9 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import create_classifier, get_padding -from .registry import register_model +from timm.layers import create_classifier, get_padding +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['Xception65'] diff --git a/timm/models/hardcorenas.py b/timm/models/hardcorenas.py index 132eeab4..d77e642a 100644 --- a/timm/models/hardcorenas.py +++ b/timm/models/hardcorenas.py @@ -3,12 +3,14 @@ from functools import partial import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .efficientnet_blocks import SqueezeExcite -from .efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args, round_channels -from .helpers import build_model_with_cfg, pretrained_cfg_for_features -from .layers import get_act_fn +from ._builder import build_model_with_cfg +from ._builder import pretrained_cfg_for_features +from ._efficientnet_blocks import SqueezeExcite +from ._efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args, round_channels +from ._registry import register_model from .mobilenetv3 import MobileNetV3, MobileNetV3Features -from .registry import register_model + +__all__ = [] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/helpers.py b/timm/models/helpers.py deleted file mode 100644 index 2a5551e0..00000000 --- a/timm/models/helpers.py +++ /dev/null @@ -1,855 +0,0 @@ -""" Model creation / weight loading / state_dict helpers - -Hacked together by / Copyright 2020 Ross Wightman -""" -import collections.abc -import dataclasses -import logging -import math -import os -import re -from collections import OrderedDict, defaultdict -from copy import deepcopy -from itertools import chain -from typing import Any, Callable, Optional, Tuple, Dict, Union - -import torch -import torch.nn as nn -from torch.hub import load_state_dict_from_url -from torch.utils.checkpoint import checkpoint - -from .pretrained import PretrainedCfg -from .features import FeatureListNet, FeatureDictNet, FeatureHookNet -from .fx_features import FeatureGraphNet -from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf -from .layers import Conv2dSame, Linear, BatchNormAct2d -from .registry import get_pretrained_cfg - - -_logger = logging.getLogger(__name__) - - -# Global variables for rarely used pretrained checkpoint download progress and hash check. -# Use set_pretrained_download_progress / set_pretrained_check_hash functions to toggle. -_DOWNLOAD_PROGRESS = False -_CHECK_HASH = False - - -def clean_state_dict(state_dict): - # 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training - cleaned_state_dict = OrderedDict() - for k, v in state_dict.items(): - name = k[7:] if k.startswith('module.') else k - cleaned_state_dict[name] = v - return cleaned_state_dict - - -def load_state_dict(checkpoint_path, use_ema=True): - if checkpoint_path and os.path.isfile(checkpoint_path): - checkpoint = torch.load(checkpoint_path, map_location='cpu') - state_dict_key = '' - if isinstance(checkpoint, dict): - if use_ema and checkpoint.get('state_dict_ema', None) is not None: - state_dict_key = 'state_dict_ema' - elif use_ema and checkpoint.get('model_ema', None) is not None: - state_dict_key = 'model_ema' - elif 'state_dict' in checkpoint: - state_dict_key = 'state_dict' - elif 'model' in checkpoint: - state_dict_key = 'model' - state_dict = clean_state_dict(checkpoint[state_dict_key] if state_dict_key else checkpoint) - _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path)) - return state_dict - else: - _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) - raise FileNotFoundError() - - -def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True, remap=False): - if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'): - # numpy checkpoint, try to load via model specific load_pretrained fn - if hasattr(model, 'load_pretrained'): - model.load_pretrained(checkpoint_path) - else: - raise NotImplementedError('Model cannot load numpy checkpoint') - return - state_dict = load_state_dict(checkpoint_path, use_ema) - if remap: - state_dict = remap_checkpoint(model, state_dict) - incompatible_keys = model.load_state_dict(state_dict, strict=strict) - return incompatible_keys - - -def remap_checkpoint(model, state_dict, allow_reshape=True): - """ remap checkpoint by iterating over state dicts in order (ignoring original keys). - This assumes models (and originating state dict) were created with params registered in same order. - """ - out_dict = {} - for (ka, va), (kb, vb) in zip(model.state_dict().items(), state_dict.items()): - assert va.numel == vb.numel, f'Tensor size mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.' - if va.shape != vb.shape: - if allow_reshape: - vb = vb.reshape(va.shape) - else: - assert False, f'Tensor shape mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.' - out_dict[ka] = vb - return out_dict - - -def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True): - resume_epoch = None - if os.path.isfile(checkpoint_path): - checkpoint = torch.load(checkpoint_path, map_location='cpu') - if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: - if log_info: - _logger.info('Restoring model state from checkpoint...') - state_dict = clean_state_dict(checkpoint['state_dict']) - model.load_state_dict(state_dict) - - if optimizer is not None and 'optimizer' in checkpoint: - if log_info: - _logger.info('Restoring optimizer state from checkpoint...') - optimizer.load_state_dict(checkpoint['optimizer']) - - if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint: - if log_info: - _logger.info('Restoring AMP loss scaler state from checkpoint...') - loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key]) - - if 'epoch' in checkpoint: - resume_epoch = checkpoint['epoch'] - if 'version' in checkpoint and checkpoint['version'] > 1: - resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save - - if log_info: - _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) - else: - model.load_state_dict(checkpoint) - if log_info: - _logger.info("Loaded checkpoint '{}'".format(checkpoint_path)) - return resume_epoch - else: - _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) - raise FileNotFoundError() - - -def _resolve_pretrained_source(pretrained_cfg): - cfg_source = pretrained_cfg.get('source', '') - pretrained_url = pretrained_cfg.get('url', None) - pretrained_file = pretrained_cfg.get('file', None) - hf_hub_id = pretrained_cfg.get('hf_hub_id', None) - # resolve where to load pretrained weights from - load_from = '' - pretrained_loc = '' - if cfg_source == 'hf-hub' and has_hf_hub(necessary=True): - # hf-hub specified as source via model identifier - load_from = 'hf-hub' - assert hf_hub_id - pretrained_loc = hf_hub_id - else: - # default source == timm or unspecified - if pretrained_file: - load_from = 'file' - pretrained_loc = pretrained_file - elif pretrained_url: - load_from = 'url' - pretrained_loc = pretrained_url - elif hf_hub_id and has_hf_hub(necessary=True): - # hf-hub available as alternate weight source in default_cfg - load_from = 'hf-hub' - pretrained_loc = hf_hub_id - if load_from == 'hf-hub' and pretrained_cfg.get('hf_hub_filename', None): - # if a filename override is set, return tuple for location w/ (hub_id, filename) - pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename'] - return load_from, pretrained_loc - - -def set_pretrained_download_progress(enable=True): - """ Set download progress for pretrained weights on/off (globally). """ - global _DOWNLOAD_PROGRESS - _DOWNLOAD_PROGRESS = enable - - -def set_pretrained_check_hash(enable=True): - """ Set hash checking for pretrained weights on/off (globally). """ - global _CHECK_HASH - _CHECK_HASH = enable - - -def load_custom_pretrained( - model: nn.Module, - pretrained_cfg: Optional[Dict] = None, - load_fn: Optional[Callable] = None, -): - r"""Loads a custom (read non .pth) weight file - - Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls - a passed in custom load fun, or the `load_pretrained` model member fn. - - If the object is already present in `model_dir`, it's deserialized and returned. - The default value of `model_dir` is ``/checkpoints`` where - `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`. - - Args: - model: The instantiated model to load weights into - pretrained_cfg (dict): Default pretrained model cfg - load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named - 'laod_pretrained' on the model will be called if it exists - """ - pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None) - if not pretrained_cfg: - _logger.warning("Invalid pretrained config, cannot load weights.") - return - - load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg) - if not load_from: - _logger.warning("No pretrained weights exist for this model. Using random initialization.") - return - if load_from == 'hf-hub': # FIXME - _logger.warning("Hugging Face hub not currently supported for custom load pretrained models.") - elif load_from == 'url': - pretrained_loc = download_cached_file( - pretrained_loc, - check_hash=_CHECK_HASH, - progress=_DOWNLOAD_PROGRESS - ) - - if load_fn is not None: - load_fn(model, pretrained_loc) - elif hasattr(model, 'load_pretrained'): - model.load_pretrained(pretrained_loc) - else: - _logger.warning("Valid function to load pretrained weights is not available, using random initialization.") - - -def adapt_input_conv(in_chans, conv_weight): - conv_type = conv_weight.dtype - conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU - O, I, J, K = conv_weight.shape - if in_chans == 1: - if I > 3: - assert conv_weight.shape[1] % 3 == 0 - # For models with space2depth stems - conv_weight = conv_weight.reshape(O, I // 3, 3, J, K) - conv_weight = conv_weight.sum(dim=2, keepdim=False) - else: - conv_weight = conv_weight.sum(dim=1, keepdim=True) - elif in_chans != 3: - if I != 3: - raise NotImplementedError('Weight format not supported by conversion.') - else: - # NOTE this strategy should be better than random init, but there could be other combinations of - # the original RGB input layer weights that'd work better for specific cases. - repeat = int(math.ceil(in_chans / 3)) - conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] - conv_weight *= (3 / float(in_chans)) - conv_weight = conv_weight.to(conv_type) - return conv_weight - - -def load_pretrained( - model: nn.Module, - pretrained_cfg: Optional[Dict] = None, - num_classes: int = 1000, - in_chans: int = 3, - filter_fn: Optional[Callable] = None, - strict: bool = True, -): - """ Load pretrained checkpoint - - Args: - model (nn.Module) : PyTorch model module - pretrained_cfg (Optional[Dict]): configuration for pretrained weights / target dataset - num_classes (int): num_classes for target model - in_chans (int): in_chans for target model - filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args) - strict (bool): strict load of checkpoint - - """ - pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None) - if not pretrained_cfg: - _logger.warning("Invalid pretrained config, cannot load weights.") - return - - load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg) - if load_from == 'file': - _logger.info(f'Loading pretrained weights from file ({pretrained_loc})') - state_dict = load_state_dict(pretrained_loc) - elif load_from == 'url': - _logger.info(f'Loading pretrained weights from url ({pretrained_loc})') - state_dict = load_state_dict_from_url( - pretrained_loc, - map_location='cpu', - progress=_DOWNLOAD_PROGRESS, - check_hash=_CHECK_HASH, - ) - elif load_from == 'hf-hub': - _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})') - if isinstance(pretrained_loc, (list, tuple)): - state_dict = load_state_dict_from_hf(*pretrained_loc) - else: - state_dict = load_state_dict_from_hf(pretrained_loc) - else: - _logger.warning("No pretrained weights exist or were found for this model. Using random initialization.") - return - - if filter_fn is not None: - # for backwards compat with filter fn that take one arg, try one first, the two - try: - state_dict = filter_fn(state_dict) - except TypeError: - state_dict = filter_fn(state_dict, model) - - input_convs = pretrained_cfg.get('first_conv', None) - if input_convs is not None and in_chans != 3: - if isinstance(input_convs, str): - input_convs = (input_convs,) - for input_conv_name in input_convs: - weight_name = input_conv_name + '.weight' - try: - state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name]) - _logger.info( - f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)') - except NotImplementedError as e: - del state_dict[weight_name] - strict = False - _logger.warning( - f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.') - - classifiers = pretrained_cfg.get('classifier', None) - label_offset = pretrained_cfg.get('label_offset', 0) - if classifiers is not None: - if isinstance(classifiers, str): - classifiers = (classifiers,) - if num_classes != pretrained_cfg['num_classes']: - for classifier_name in classifiers: - # completely discard fully connected if model num_classes doesn't match pretrained weights - state_dict.pop(classifier_name + '.weight', None) - state_dict.pop(classifier_name + '.bias', None) - strict = False - elif label_offset > 0: - for classifier_name in classifiers: - # special case for pretrained weights with an extra background class in pretrained weights - classifier_weight = state_dict[classifier_name + '.weight'] - state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:] - classifier_bias = state_dict[classifier_name + '.bias'] - state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:] - - model.load_state_dict(state_dict, strict=strict) - - -def extract_layer(model, layer): - layer = layer.split('.') - module = model - if hasattr(model, 'module') and layer[0] != 'module': - module = model.module - if not hasattr(model, 'module') and layer[0] == 'module': - layer = layer[1:] - for l in layer: - if hasattr(module, l): - if not l.isdigit(): - module = getattr(module, l) - else: - module = module[int(l)] - else: - return module - return module - - -def set_layer(model, layer, val): - layer = layer.split('.') - module = model - if hasattr(model, 'module') and layer[0] != 'module': - module = model.module - lst_index = 0 - module2 = module - for l in layer: - if hasattr(module2, l): - if not l.isdigit(): - module2 = getattr(module2, l) - else: - module2 = module2[int(l)] - lst_index += 1 - lst_index -= 1 - for l in layer[:lst_index]: - if not l.isdigit(): - module = getattr(module, l) - else: - module = module[int(l)] - l = layer[lst_index] - setattr(module, l, val) - - -def adapt_model_from_string(parent_module, model_string): - separator = '***' - state_dict = {} - lst_shape = model_string.split(separator) - for k in lst_shape: - k = k.split(':') - key = k[0] - shape = k[1][1:-1].split(',') - if shape[0] != '': - state_dict[key] = [int(i) for i in shape] - - new_module = deepcopy(parent_module) - for n, m in parent_module.named_modules(): - old_module = extract_layer(parent_module, n) - if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame): - if isinstance(old_module, Conv2dSame): - conv = Conv2dSame - else: - conv = nn.Conv2d - s = state_dict[n + '.weight'] - in_channels = s[1] - out_channels = s[0] - g = 1 - if old_module.groups > 1: - in_channels = out_channels - g = in_channels - new_conv = conv( - in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size, - bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation, - groups=g, stride=old_module.stride) - set_layer(new_module, n, new_conv) - elif isinstance(old_module, BatchNormAct2d): - new_bn = BatchNormAct2d( - state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum, - affine=old_module.affine, track_running_stats=True) - new_bn.drop = old_module.drop - new_bn.act = old_module.act - set_layer(new_module, n, new_bn) - elif isinstance(old_module, nn.BatchNorm2d): - new_bn = nn.BatchNorm2d( - num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum, - affine=old_module.affine, track_running_stats=True) - set_layer(new_module, n, new_bn) - elif isinstance(old_module, nn.Linear): - # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer? - num_features = state_dict[n + '.weight'][1] - new_fc = Linear( - in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None) - set_layer(new_module, n, new_fc) - if hasattr(new_module, 'num_features'): - new_module.num_features = num_features - new_module.eval() - parent_module.eval() - - return new_module - - -def adapt_model_from_file(parent_module, model_variant): - adapt_file = os.path.join(os.path.dirname(__file__), 'pruned', model_variant + '.txt') - with open(adapt_file, 'r') as f: - return adapt_model_from_string(parent_module, f.read().strip()) - - -def pretrained_cfg_for_features(pretrained_cfg): - pretrained_cfg = deepcopy(pretrained_cfg) - # remove default pretrained cfg fields that don't have much relevance for feature backbone - to_remove = ('num_classes', 'classifier', 'global_pool') # add default final pool size? - for tr in to_remove: - pretrained_cfg.pop(tr, None) - return pretrained_cfg - - -def _filter_kwargs(kwargs, names): - if not kwargs or not names: - return - for n in names: - kwargs.pop(n, None) - - -def _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter): - """ Update the default_cfg and kwargs before passing to model - - Args: - pretrained_cfg: input pretrained cfg (updated in-place) - kwargs: keyword args passed to model build fn (updated in-place) - kwargs_filter: keyword arg keys that must be removed before model __init__ - """ - # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs) - default_kwarg_names = ('num_classes', 'global_pool', 'in_chans') - if pretrained_cfg.get('fixed_input_size', False): - # if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size - default_kwarg_names += ('img_size',) - - for n in default_kwarg_names: - # for legacy reasons, model __init__args uses img_size + in_chans as separate args while - # pretrained_cfg has one input_size=(C, H ,W) entry - if n == 'img_size': - input_size = pretrained_cfg.get('input_size', None) - if input_size is not None: - assert len(input_size) == 3 - kwargs.setdefault(n, input_size[-2:]) - elif n == 'in_chans': - input_size = pretrained_cfg.get('input_size', None) - if input_size is not None: - assert len(input_size) == 3 - kwargs.setdefault(n, input_size[0]) - else: - default_val = pretrained_cfg.get(n, None) - if default_val is not None: - kwargs.setdefault(n, pretrained_cfg[n]) - - # Filter keyword args for task specific model variants (some 'features only' models, etc.) - _filter_kwargs(kwargs, names=kwargs_filter) - - -def resolve_pretrained_cfg( - variant: str, - pretrained_cfg=None, - pretrained_cfg_overlay=None, -) -> PretrainedCfg: - model_with_tag = variant - pretrained_tag = None - if pretrained_cfg: - if isinstance(pretrained_cfg, dict): - # pretrained_cfg dict passed as arg, validate by converting to PretrainedCfg - pretrained_cfg = PretrainedCfg(**pretrained_cfg) - elif isinstance(pretrained_cfg, str): - pretrained_tag = pretrained_cfg - pretrained_cfg = None - - # fallback to looking up pretrained cfg in model registry by variant identifier - if not pretrained_cfg: - if pretrained_tag: - model_with_tag = '.'.join([variant, pretrained_tag]) - pretrained_cfg = get_pretrained_cfg(model_with_tag) - - if not pretrained_cfg: - _logger.warning( - f"No pretrained configuration specified for {model_with_tag} model. Using a default." - f" Please add a config to the model pretrained_cfg registry or pass explicitly.") - pretrained_cfg = PretrainedCfg() # instance with defaults - - pretrained_cfg_overlay = pretrained_cfg_overlay or {} - if not pretrained_cfg.architecture: - pretrained_cfg_overlay.setdefault('architecture', variant) - pretrained_cfg = dataclasses.replace(pretrained_cfg, **pretrained_cfg_overlay) - - return pretrained_cfg - - -def build_model_with_cfg( - model_cls: Callable, - variant: str, - pretrained: bool, - pretrained_cfg: Optional[Dict] = None, - pretrained_cfg_overlay: Optional[Dict] = None, - model_cfg: Optional[Any] = None, - feature_cfg: Optional[Dict] = None, - pretrained_strict: bool = True, - pretrained_filter_fn: Optional[Callable] = None, - kwargs_filter: Optional[Tuple[str]] = None, - **kwargs, -): - """ Build model with specified default_cfg and optional model_cfg - - This helper fn aids in the construction of a model including: - * handling default_cfg and associated pretrained weight loading - * passing through optional model_cfg for models with config based arch spec - * features_only model adaptation - * pruning config / model adaptation - - Args: - model_cls (nn.Module): model class - variant (str): model variant name - pretrained (bool): load pretrained weights - pretrained_cfg (dict): model's pretrained weight/task config - model_cfg (Optional[Dict]): model's architecture config - feature_cfg (Optional[Dict]: feature extraction adapter config - pretrained_strict (bool): load pretrained weights strictly - pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights - kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model - **kwargs: model args passed through to model __init__ - """ - pruned = kwargs.pop('pruned', False) - features = False - feature_cfg = feature_cfg or {} - - # resolve and update model pretrained config and model kwargs - pretrained_cfg = resolve_pretrained_cfg( - variant, - pretrained_cfg=pretrained_cfg, - pretrained_cfg_overlay=pretrained_cfg_overlay - ) - - # FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model - pretrained_cfg = pretrained_cfg.to_dict() - - _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter) - - # Setup for feature extraction wrapper done at end of this fn - if kwargs.pop('features_only', False): - features = True - feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4)) - if 'out_indices' in kwargs: - feature_cfg['out_indices'] = kwargs.pop('out_indices') - - # Instantiate the model - if model_cfg is None: - model = model_cls(**kwargs) - else: - model = model_cls(cfg=model_cfg, **kwargs) - model.pretrained_cfg = pretrained_cfg - model.default_cfg = model.pretrained_cfg # alias for backwards compat - - if pruned: - model = adapt_model_from_file(model, variant) - - # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats - num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000)) - if pretrained: - if pretrained_cfg.get('custom_load', False): - load_custom_pretrained( - model, - pretrained_cfg=pretrained_cfg, - ) - else: - load_pretrained( - model, - pretrained_cfg=pretrained_cfg, - num_classes=num_classes_pretrained, - in_chans=kwargs.get('in_chans', 3), - filter_fn=pretrained_filter_fn, - strict=pretrained_strict, - ) - - # Wrap the model in a feature extraction module if enabled - if features: - feature_cls = FeatureListNet - if 'feature_cls' in feature_cfg: - feature_cls = feature_cfg.pop('feature_cls') - if isinstance(feature_cls, str): - feature_cls = feature_cls.lower() - if 'hook' in feature_cls: - feature_cls = FeatureHookNet - elif feature_cls == 'fx': - feature_cls = FeatureGraphNet - else: - assert False, f'Unknown feature class {feature_cls}' - model = feature_cls(model, **feature_cfg) - model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg) # add back default_cfg - model.default_cfg = model.pretrained_cfg # alias for backwards compat - - return model - - -def model_parameters(model, exclude_head=False): - if exclude_head: - # FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering - return [p for p in model.parameters()][:-2] - else: - return model.parameters() - - -def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module: - if not depth_first and include_root: - fn(module=module, name=name) - for child_name, child_module in module.named_children(): - child_name = '.'.join((name, child_name)) if name else child_name - named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) - if depth_first and include_root: - fn(module=module, name=name) - return module - - -def named_modules(module: nn.Module, name='', depth_first=True, include_root=False): - if not depth_first and include_root: - yield name, module - for child_name, child_module in module.named_children(): - child_name = '.'.join((name, child_name)) if name else child_name - yield from named_modules( - module=child_module, name=child_name, depth_first=depth_first, include_root=True) - if depth_first and include_root: - yield name, module - - -def named_modules_with_params(module: nn.Module, name='', depth_first=True, include_root=False): - if module._parameters and not depth_first and include_root: - yield name, module - for child_name, child_module in module.named_children(): - child_name = '.'.join((name, child_name)) if name else child_name - yield from named_modules_with_params( - module=child_module, name=child_name, depth_first=depth_first, include_root=True) - if module._parameters and depth_first and include_root: - yield name, module - - -MATCH_PREV_GROUP = (99999,) - - -def group_with_matcher( - named_objects, - group_matcher: Union[Dict, Callable], - output_values: bool = False, - reverse: bool = False -): - if isinstance(group_matcher, dict): - # dictionary matcher contains a dict of raw-string regex expr that must be compiled - compiled = [] - for group_ordinal, (group_name, mspec) in enumerate(group_matcher.items()): - if mspec is None: - continue - # map all matching specifications into 3-tuple (compiled re, prefix, suffix) - if isinstance(mspec, (tuple, list)): - # multi-entry match specifications require each sub-spec to be a 2-tuple (re, suffix) - for sspec in mspec: - compiled += [(re.compile(sspec[0]), (group_ordinal,), sspec[1])] - else: - compiled += [(re.compile(mspec), (group_ordinal,), None)] - group_matcher = compiled - - def _get_grouping(name): - if isinstance(group_matcher, (list, tuple)): - for match_fn, prefix, suffix in group_matcher: - r = match_fn.match(name) - if r: - parts = (prefix, r.groups(), suffix) - # map all tuple elem to int for numeric sort, filter out None entries - return tuple(map(float, chain.from_iterable(filter(None, parts)))) - return float('inf'), # un-matched layers (neck, head) mapped to largest ordinal - else: - ord = group_matcher(name) - if not isinstance(ord, collections.abc.Iterable): - return ord, - return tuple(ord) - - # map layers into groups via ordinals (ints or tuples of ints) from matcher - grouping = defaultdict(list) - for k, v in named_objects: - grouping[_get_grouping(k)].append(v if output_values else k) - - # remap to integers - layer_id_to_param = defaultdict(list) - lid = -1 - for k in sorted(filter(lambda x: x is not None, grouping.keys())): - if lid < 0 or k[-1] != MATCH_PREV_GROUP[0]: - lid += 1 - layer_id_to_param[lid].extend(grouping[k]) - - if reverse: - assert not output_values, "reverse mapping only sensible for name output" - # output reverse mapping - param_to_layer_id = {} - for lid, lm in layer_id_to_param.items(): - for n in lm: - param_to_layer_id[n] = lid - return param_to_layer_id - - return layer_id_to_param - - -def group_parameters( - module: nn.Module, - group_matcher, - output_values=False, - reverse=False, -): - return group_with_matcher( - module.named_parameters(), group_matcher, output_values=output_values, reverse=reverse) - - -def group_modules( - module: nn.Module, - group_matcher, - output_values=False, - reverse=False, -): - return group_with_matcher( - named_modules_with_params(module), group_matcher, output_values=output_values, reverse=reverse) - - -def checkpoint_seq( - functions, - x, - every=1, - flatten=False, - skip_last=False, - preserve_rng_state=True -): - r"""A helper function for checkpointing sequential models. - - Sequential models execute a list of modules/functions in order - (sequentially). Therefore, we can divide such a sequence into segments - and checkpoint each segment. All segments except run in :func:`torch.no_grad` - manner, i.e., not storing the intermediate activations. The inputs of each - checkpointed segment will be saved for re-running the segment in the backward pass. - - See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works. - - .. warning:: - Checkpointing currently only supports :func:`torch.autograd.backward` - and only if its `inputs` argument is not passed. :func:`torch.autograd.grad` - is not supported. - - .. warning: - At least one of the inputs needs to have :code:`requires_grad=True` if - grads are needed for model inputs, otherwise the checkpointed part of the - model won't have gradients. - - Args: - functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially. - x: A Tensor that is input to :attr:`functions` - every: checkpoint every-n functions (default: 1) - flatten (bool): flatten nn.Sequential of nn.Sequentials - skip_last (bool): skip checkpointing the last function in the sequence if True - preserve_rng_state (bool, optional, default=True): Omit stashing and restoring - the RNG state during each checkpoint. - - Returns: - Output of running :attr:`functions` sequentially on :attr:`*inputs` - - Example: - >>> model = nn.Sequential(...) - >>> input_var = checkpoint_seq(model, input_var, every=2) - """ - def run_function(start, end, functions): - def forward(_x): - for j in range(start, end + 1): - _x = functions[j](_x) - return _x - return forward - - if isinstance(functions, torch.nn.Sequential): - functions = functions.children() - if flatten: - functions = chain.from_iterable(functions) - if not isinstance(functions, (tuple, list)): - functions = tuple(functions) - - num_checkpointed = len(functions) - if skip_last: - num_checkpointed -= 1 - end = -1 - for start in range(0, num_checkpointed, every): - end = min(start + every - 1, num_checkpointed - 1) - x = checkpoint(run_function(start, end, functions), x, preserve_rng_state=preserve_rng_state) - if skip_last: - return run_function(end + 1, len(functions) - 1, functions)(x) - return x - - -def flatten_modules(named_modules, depth=1, prefix='', module_types='sequential'): - prefix_is_tuple = isinstance(prefix, tuple) - if isinstance(module_types, str): - if module_types == 'container': - module_types = (nn.Sequential, nn.ModuleList, nn.ModuleDict) - else: - module_types = (nn.Sequential,) - for name, module in named_modules: - if depth and isinstance(module, module_types): - yield from flatten_modules( - module.named_children(), - depth - 1, - prefix=(name,) if prefix_is_tuple else name, - module_types=module_types, - ) - else: - if prefix_is_tuple: - name = prefix + (name,) - yield name, module - else: - if prefix: - name = '.'.join([prefix, name]) - yield name, module diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py index 30860120..338d409e 100644 --- a/timm/models/hrnet.py +++ b/timm/models/hrnet.py @@ -16,12 +16,14 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .features import FeatureInfo -from .helpers import build_model_with_cfg, pretrained_cfg_for_features -from .layers import create_classifier -from .registry import register_model +from timm.layers import create_classifier +from ._builder import build_model_with_cfg, pretrained_cfg_for_features +from ._features import FeatureInfo +from ._registry import register_model from .resnet import BasicBlock, Bottleneck # leveraging ResNet blocks w/ additional features like SE +__all__ = ['HighResolutionNet', 'HighResolutionNetFeatures'] # model_registry will add each entrypoint fn to this + _BN_MOMENTUM = 0.1 _logger = logging.getLogger(__name__) diff --git a/timm/models/inception_resnet_v2.py b/timm/models/inception_resnet_v2.py index fa7b8ec8..3006f3d2 100644 --- a/timm/models/inception_resnet_v2.py +++ b/timm/models/inception_resnet_v2.py @@ -7,9 +7,10 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg, flatten_modules -from .layers import create_classifier -from .registry import register_model +from timm.layers import create_classifier +from ._builder import build_model_with_cfg +from ._manipulate import flatten_modules +from ._registry import register_model __all__ = ['InceptionResnetV2'] diff --git a/timm/models/inception_v3.py b/timm/models/inception_v3.py index c70bd608..28794ce6 100644 --- a/timm/models/inception_v3.py +++ b/timm/models/inception_v3.py @@ -8,9 +8,13 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg, resolve_pretrained_cfg, flatten_modules -from .registry import register_model -from .layers import trunc_normal_, create_classifier, Linear +from timm.layers import trunc_normal_, create_classifier, Linear +from ._builder import build_model_with_cfg +from ._builder import resolve_pretrained_cfg +from ._manipulate import flatten_modules +from ._registry import register_model + +__all__ = ['InceptionV3', 'InceptionV3Aux'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/inception_v4.py b/timm/models/inception_v4.py index 5f4e208f..c1559829 100644 --- a/timm/models/inception_v4.py +++ b/timm/models/inception_v4.py @@ -7,9 +7,9 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg -from .layers import create_classifier -from .registry import register_model +from timm.layers import create_classifier +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['InceptionV4'] diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 21c641b6..1bfb95d5 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -1,44 +1,45 @@ -from .activations import * -from .adaptive_avgmax_pool import \ +# NOTE timm.models.layers is DEPRECATED, please use timm.layers, this is here to reduce breakages in transition +from timm.layers.activations import * +from timm.layers.adaptive_avgmax_pool import \ adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d -from .blur_pool import BlurPool2d -from .classifier import ClassifierHead, create_classifier -from .cond_conv2d import CondConv2d, get_condconv_initializer -from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ +from timm.layers.blur_pool import BlurPool2d +from timm.layers.classifier import ClassifierHead, create_classifier +from timm.layers.cond_conv2d import CondConv2d, get_condconv_initializer +from timm.layers.config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ set_layer_config -from .conv2d_same import Conv2dSame, conv2d_same -from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct -from .create_act import create_act_layer, get_act_layer, get_act_fn -from .create_attn import get_attn, create_attn -from .create_conv2d import create_conv2d -from .create_norm import get_norm_layer, create_norm_layer -from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer -from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path -from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn -from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\ +from timm.layers.conv2d_same import Conv2dSame, conv2d_same +from timm.layers.conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct +from timm.layers.create_act import create_act_layer, get_act_layer, get_act_fn +from timm.layers.create_attn import get_attn, create_attn +from timm.layers.create_conv2d import create_conv2d +from timm.layers.create_norm import get_norm_layer, create_norm_layer +from timm.layers.create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer +from timm.layers.drop import DropBlock2d, DropPath, drop_block_2d, drop_path +from timm.layers.eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn +from timm.layers.evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\ EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a -from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm -from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d -from .gather_excite import GatherExcite -from .global_context import GlobalContext -from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple -from .inplace_abn import InplaceAbn -from .linear import Linear -from .mixed_conv2d import MixedConv2d -from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp -from .non_local_attn import NonLocalAttn, BatNonLocalAttn -from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d -from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm -from .padding import get_padding, get_same_padding, pad_same -from .patch_embed import PatchEmbed -from .pool2d_same import AvgPool2dSame, create_pool2d -from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite -from .selective_kernel import SelectiveKernel -from .separable_conv import SeparableConv2d, SeparableConvNormAct -from .space_to_depth import SpaceToDepthModule -from .split_attn import SplitAttn -from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model -from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame -from .test_time_pool import TestTimePoolHead, apply_test_time_pool -from .trace_utils import _assert, _float_to_int -from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_ +from timm.layers.fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm +from timm.layers.filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d +from timm.layers.gather_excite import GatherExcite +from timm.layers.global_context import GlobalContext +from timm.layers.helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple +from timm.layers.inplace_abn import InplaceAbn +from timm.layers.linear import Linear +from timm.layers.mixed_conv2d import MixedConv2d +from timm.layers.mlp import Mlp, GluMlp, GatedMlp, ConvMlp +from timm.layers.non_local_attn import NonLocalAttn, BatNonLocalAttn +from timm.layers.norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d +from timm.layers.norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm +from timm.layers.padding import get_padding, get_same_padding, pad_same +from timm.layers.patch_embed import PatchEmbed +from timm.layers.pool2d_same import AvgPool2dSame, create_pool2d +from timm.layers.squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite +from timm.layers.selective_kernel import SelectiveKernel +from timm.layers.separable_conv import SeparableConv2d, SeparableConvNormAct +from timm.layers.space_to_depth import SpaceToDepthModule +from timm.layers.split_attn import SplitAttn +from timm.layers.split_batchnorm import SplitBatchNorm2d, convert_splitbn_model +from timm.layers.std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame +from timm.layers.test_time_pool import TestTimePoolHead, apply_test_time_pool +from timm.layers.trace_utils import _assert, _float_to_int +from timm.layers.weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_ diff --git a/timm/models/levit.py b/timm/models/levit.py index cea9f0fc..8dc11309 100644 --- a/timm/models/levit.py +++ b/timm/models/levit.py @@ -23,8 +23,6 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W # Modified from # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # Copyright 2020 Ross Wightman, Apache-2.0 License -import itertools -from copy import deepcopy from functools import partial from typing import Dict @@ -32,10 +30,12 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import to_ntuple, get_act_layer -from .vision_transformer import trunc_normal_ -from .registry import register_model +from timm.layers import to_ntuple, get_act_layer, trunc_normal_ +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model + +__all__ = ['LevitDistilled'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index 3f315093..1e2666e5 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -45,17 +45,17 @@ from typing import Callable, Optional, Union, Tuple, List import torch from torch import nn -from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, checkpoint_seq, named_apply -from .fx_features import register_notrace_function -from .layers import Mlp, ConvMlp, DropPath, ClassifierHead, trunc_normal_tf_, LayerNorm2d, LayerNorm -from .layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d -from .layers import SelectAdaptivePool2d, create_pool2d -from .layers import to_2tuple, extend_tuple, make_divisible, _assert -from .pretrained import generate_default_cfgs -from .registry import register_model +from timm.layers import Mlp, ConvMlp, DropPath, ClassifierHead, trunc_normal_tf_, LayerNorm +from timm.layers import SelectAdaptivePool2d, create_pool2d +from timm.layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d +from timm.layers import to_2tuple, extend_tuple, make_divisible, _assert +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._manipulate import named_apply, checkpoint_seq +from ._pretrained import generate_default_cfgs +from ._registry import register_model from .vision_transformer_relpos import RelPosMlp, RelPosBias # FIXME move these to common location __all__ = ['MaxxVitCfg', 'MaxxVitConvCfg', 'MaxxVitTransformerCfg', 'MaxxVit'] diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py index a77e2eb7..a7825899 100644 --- a/timm/models/mlp_mixer.py +++ b/timm/models/mlp_mixer.py @@ -39,16 +39,18 @@ A thank you to paper authors for releasing code and weights. Hacked together by / Copyright 2021 Ross Wightman """ import math -from copy import deepcopy from functools import partial import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, named_apply, checkpoint_seq -from .layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple -from .registry import register_model +from timm.layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple +from ._builder import build_model_with_cfg +from ._manipulate import named_apply, checkpoint_seq +from ._registry import register_model + +__all__ = ['MixerBlock'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index bb72ccb8..cf4f268d 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -14,13 +14,14 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .efficientnet_blocks import SqueezeExcite -from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights,\ +from timm.layers import SelectAdaptivePool2d, Linear, create_conv2d, get_norm_act_layer +from ._builder import build_model_with_cfg, pretrained_cfg_for_features +from ._efficientnet_blocks import SqueezeExcite +from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \ round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT -from .features import FeatureInfo, FeatureHooks -from .helpers import build_model_with_cfg, pretrained_cfg_for_features, checkpoint_seq -from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, get_norm_act_layer -from .registry import register_model +from ._features import FeatureInfo, FeatureHooks +from ._manipulate import checkpoint_seq +from ._registry import register_model __all__ = ['MobileNetV3', 'MobileNetV3Features'] diff --git a/timm/models/mobilevit.py b/timm/models/mobilevit.py index bd5479a7..3d2ae84a 100644 --- a/timm/models/mobilevit.py +++ b/timm/models/mobilevit.py @@ -14,18 +14,18 @@ Rest of code, ByobNet, and Transformer block hacked together by / Copyright 2022 # Copyright (C) 2020 Apple Inc. All Rights Reserved. # import math -from typing import Union, Callable, Dict, Tuple, Optional, Sequence +from typing import Callable, Tuple, Optional import torch -from torch import nn import torch.nn.functional as F +from torch import nn +from timm.layers import to_2tuple, make_divisible, GroupNorm1, ConvMlp, DropPath +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_module +from ._registry import register_model from .byobnet import register_block, ByoBlockCfg, ByoModelCfg, ByobNet, LayerFn, num_groups -from .fx_features import register_notrace_module -from .layers import to_2tuple, make_divisible, LayerNorm2d, GroupNorm1, ConvMlp, DropPath from .vision_transformer import Block as TransformerBlock -from .helpers import build_model_with_cfg -from .registry import register_model __all__ = [] diff --git a/timm/models/mvitv2.py b/timm/models/mvitv2.py index c5aaa09e..5c0a6650 100644 --- a/timm/models/mvitv2.py +++ b/timm/models/mvitv2.py @@ -24,10 +24,12 @@ import torch.utils.checkpoint as checkpoint from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_function -from .helpers import build_model_with_cfg -from .layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple -from .registry import register_model +from timm.layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._registry import register_model + +__all__ = ['MultiScaleVit', 'MultiScaleVitCfg'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py index 50db1a3d..0b2178d6 100644 --- a/timm/models/nasnet.py +++ b/timm/models/nasnet.py @@ -8,9 +8,9 @@ import torch import torch.nn as nn import torch.nn.functional as F -from .helpers import build_model_with_cfg -from .layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier -from .registry import register_model +from timm.layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['NASNetALarge'] diff --git a/timm/models/nest.py b/timm/models/nest.py index 8692a2b1..c9c6258c 100644 --- a/timm/models/nest.py +++ b/timm/models/nest.py @@ -25,12 +25,14 @@ import torch.nn.functional as F from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_function -from .helpers import build_model_with_cfg, named_apply, checkpoint_seq -from .layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_ -from .layers import _assert -from .layers import create_conv2d, create_pool2d, to_ntuple -from .registry import register_model +from timm.layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_, _assert +from timm.layers import create_conv2d, create_pool2d, to_ntuple +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._manipulate import checkpoint_seq, named_apply +from ._registry import register_model + +__all__ = ['Nest'] # model_registry will add each entrypoint fn to this _logger = logging.getLogger(__name__) diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index 3a45410b..48f91b35 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -16,21 +16,23 @@ Status: Hacked together by / copyright Ross Wightman, 2021. """ -import math -from dataclasses import dataclass, field from collections import OrderedDict -from typing import Tuple, Optional +from dataclasses import dataclass from functools import partial +from typing import Tuple, Optional import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_module -from .helpers import build_model_with_cfg, checkpoint_seq -from .registry import register_model -from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\ +from timm.layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame, \ get_act_layer, get_act_fn, get_attn, make_divisible +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_module +from ._manipulate import checkpoint_seq +from ._registry import register_model + +__all__ = ['NormFreeNet', 'NfCfg'] # model_registry will add each entrypoint fn to this def _dcfg(url='', **kwargs): diff --git a/timm/models/pit.py b/timm/models/pit.py index 0f571319..4f40e5e0 100644 --- a/timm/models/pit.py +++ b/timm/models/pit.py @@ -13,7 +13,6 @@ Modifications for timm by / Copyright 2020 Ross Wightman import math import re -from copy import deepcopy from functools import partial from typing import Tuple @@ -21,12 +20,15 @@ import torch from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import trunc_normal_, to_2tuple -from .registry import register_model +from timm.layers import trunc_normal_, to_2tuple +from ._builder import build_model_with_cfg +from ._registry import register_model from .vision_transformer import Block +__all__ = ['PoolingVisionTransformer'] # model_registry will add each entrypoint fn to this + + def _cfg(url='', **kwargs): return { 'url': url, diff --git a/timm/models/pnasnet.py b/timm/models/pnasnet.py index 81067845..7291c8fb 100644 --- a/timm/models/pnasnet.py +++ b/timm/models/pnasnet.py @@ -12,9 +12,9 @@ import torch import torch.nn as nn import torch.nn.functional as F -from .helpers import build_model_with_cfg -from .layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier -from .registry import register_model +from timm.layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['PNASNet5Large'] diff --git a/timm/models/poolformer.py b/timm/models/poolformer.py index 09359bc8..b4d2d18f 100644 --- a/timm/models/poolformer.py +++ b/timm/models/poolformer.py @@ -19,15 +19,15 @@ Modifications and additions for timm by / Copyright 2022, Ross Wightman # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os -import copy import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import DropPath, trunc_normal_, to_2tuple, ConvMlp, GroupNorm1 -from .registry import register_model +from timm.layers import DropPath, trunc_normal_, to_2tuple, ConvMlp, GroupNorm1 +from ._builder import build_model_with_cfg +from ._registry import register_model + +__all__ = ['PoolFormer'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/pvt_v2.py b/timm/models/pvt_v2.py index dd3cf690..696a2506 100644 --- a/timm/models/pvt_v2.py +++ b/timm/models/pvt_v2.py @@ -24,9 +24,9 @@ import torch.nn as nn import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import DropPath, to_2tuple, to_ntuple, trunc_normal_ -from .registry import register_model +from timm.layers import DropPath, to_2tuple, to_ntuple, trunc_normal_ +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['PyramidVisionTransformerV2'] diff --git a/timm/models/regnet.py b/timm/models/regnet.py index 0ad7c826..e1cc821b 100644 --- a/timm/models/regnet.py +++ b/timm/models/regnet.py @@ -23,10 +23,13 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, named_apply, checkpoint_seq -from .layers import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, GroupNormAct -from .layers import get_act_layer, get_norm_act_layer, create_conv2d -from .registry import register_model +from timm.layers import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, GroupNormAct +from timm.layers import get_act_layer, get_norm_act_layer, create_conv2d +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq, named_apply +from ._registry import register_model + +__all__ = ['RegNet', 'RegNetCfg'] # model_registry will add each entrypoint fn to this @dataclass diff --git a/timm/models/res2net.py b/timm/models/res2net.py index 6c2fd1bf..4724df2a 100644 --- a/timm/models/res2net.py +++ b/timm/models/res2net.py @@ -8,8 +8,8 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .registry import register_model +from ._builder import build_model_with_cfg +from ._registry import register_model from .resnet import ResNet __all__ = [] diff --git a/timm/models/resnest.py b/timm/models/resnest.py index 735b91a2..3b001c7b 100644 --- a/timm/models/resnest.py +++ b/timm/models/resnest.py @@ -6,13 +6,12 @@ Adapted from original PyTorch impl w/ weights at https://github.com/zhanghang198 Modified for torchscript compat, and consistency with timm by Ross Wightman """ -import torch from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import SplitAttn -from .registry import register_model +from timm.layers import SplitAttn +from ._builder import build_model_with_cfg +from ._registry import register_model from .resnet import ResNet diff --git a/timm/models/resnet.py b/timm/models/resnet.py index d0d98894..50849017 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -15,9 +15,11 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, create_attn, get_attn, create_classifier -from .registry import register_model +from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, create_attn, get_attn, \ + create_classifier +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model, model_entrypoint __all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this @@ -675,6 +677,11 @@ class ResNet(nn.Module): self.init_weights(zero_init_last=zero_init_last) + @staticmethod + def from_pretrained(model_name: str, load_weights=True, **kwargs) -> 'ResNet': + entry_fn = model_entrypoint(model_name, 'resnet') + return entry_fn(pretrained=not load_weights, **kwargs) + @torch.jit.ignore def init_weights(self, zero_init_last=True): for n, m in self.named_modules(): @@ -822,7 +829,7 @@ def resnet50(pretrained=False, **kwargs): @register_model -def resnet50d(pretrained=False, **kwargs): +def resnet50d(pretrained=False, **kwargs) -> ResNet: """Constructs a ResNet-50-D model. """ model_args = dict( diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index b21ef7f5..f8c4298b 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -30,16 +30,19 @@ Original copyright of Google code below, modifications by Ross Wightman, Copyrig # limitations under the License. from collections import OrderedDict # pylint: disable=g-importing-member +from functools import partial import torch import torch.nn as nn -from functools import partial from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg, named_apply, adapt_input_conv, checkpoint_seq -from .registry import register_model -from .layers import GroupNormAct, BatchNormAct2d, EvoNorm2dB0, EvoNorm2dS0, EvoNorm2dS1, FilterResponseNormTlu2d,\ +from timm.layers import GroupNormAct, BatchNormAct2d, EvoNorm2dB0, EvoNorm2dS0, FilterResponseNormTlu2d, \ ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq, named_apply, adapt_input_conv +from ._registry import register_model + +__all__ = ['ResNetV2'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py index 33e97222..51e8cdc2 100644 --- a/timm/models/rexnet.py +++ b/timm/models/rexnet.py @@ -10,16 +10,20 @@ Changes for timm, feature extraction, and rounded channel variant hacked togethe Copyright 2020 Ross Wightman """ -import torch -import torch.nn as nn from functools import partial from math import ceil +import torch +import torch.nn as nn + from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import ClassifierHead, create_act_layer, ConvNormAct, DropPath, make_divisible, SEModule -from .registry import register_model -from .efficientnet_builder import efficientnet_init_weights +from timm.layers import ClassifierHead, create_act_layer, ConvNormAct, DropPath, make_divisible, SEModule +from ._builder import build_model_with_cfg +from ._efficientnet_builder import efficientnet_init_weights +from ._manipulate import checkpoint_seq +from ._registry import register_model + +__all__ = ['ReXNetV1'] # model_registry will add each entrypoint fn to this def _cfg(url=''): diff --git a/timm/models/selecsls.py b/timm/models/selecsls.py index 1a9ac929..4d40c49a 100644 --- a/timm/models/selecsls.py +++ b/timm/models/selecsls.py @@ -16,9 +16,9 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import create_classifier -from .registry import register_model +from timm.layers import create_classifier +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['SelecSLS'] # model_registry will add each entrypoint fn to this diff --git a/timm/models/senet.py b/timm/models/senet.py index a9e23ff1..d36e9854 100644 --- a/timm/models/senet.py +++ b/timm/models/senet.py @@ -19,9 +19,9 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import create_classifier -from .registry import register_model +from timm.layers import create_classifier +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['SENet'] diff --git a/timm/models/sequencer.py b/timm/models/sequencer.py index b1ae92a4..f3f758b9 100644 --- a/timm/models/sequencer.py +++ b/timm/models/sequencer.py @@ -6,7 +6,6 @@ Paper: `Sequencer: Deep LSTM for Image Classification` - https://arxiv.org/pdf/2 # Copyright (c) 2022. Yuki Tatsunami # Licensed under the Apache License, Version 2.0 (the "License"); - import math from functools import partial from typing import Tuple @@ -15,9 +14,12 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT -from .helpers import build_model_with_cfg, named_apply -from .layers import lecun_normal_, DropPath, Mlp, PatchEmbed as TimmPatchEmbed -from .registry import register_model +from timm.layers import lecun_normal_, DropPath, Mlp, PatchEmbed as TimmPatchEmbed +from ._builder import build_model_with_cfg +from ._manipulate import named_apply +from ._registry import register_model + +__all__ = ['Sequencer2D'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/sknet.py b/timm/models/sknet.py index fb9f063a..5a29b9a4 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -13,9 +13,9 @@ import math from torch import nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import SelectiveKernel, ConvNormAct, ConvNormActAa, create_attn -from .registry import register_model +from timm.layers import SelectiveKernel, ConvNormAct, create_attn +from ._builder import build_model_with_cfg +from ._registry import register_model from .resnet import ResNet diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index f2305fb2..5df06d4d 100644 --- a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -17,19 +17,20 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W # -------------------------------------------------------- import logging import math -from functools import partial from typing import Optional import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_function -from .helpers import build_model_with_cfg, named_apply, checkpoint_seq -from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert -from .registry import register_model +from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._manipulate import checkpoint_seq, named_apply +from ._registry import register_model from .vision_transformer import checkpoint_filter_fn, get_init_weights_vit +__all__ = ['SwinTransformer'] # model_registry will add each entrypoint fn to this _logger = logging.getLogger(__name__) diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py index 0c9db3dd..efaaa9e9 100644 --- a/timm/models/swin_transformer_v2.py +++ b/timm/models/swin_transformer_v2.py @@ -21,10 +21,12 @@ import torch.nn.functional as F import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_function -from .helpers import build_model_with_cfg, named_apply -from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert -from .registry import register_model +from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._registry import register_model + +__all__ = ['SwinTransformerV2'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index d143c14c..cf10b39c 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -29,7 +29,6 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W # -------------------------------------------------------- import logging import math -from copy import deepcopy from typing import Tuple, Optional, List, Union, Any, Type import torch @@ -38,11 +37,13 @@ import torch.nn.functional as F import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_function -from .helpers import build_model_with_cfg, named_apply -from .layers import DropPath, Mlp, to_2tuple, _assert -from .registry import register_model +from timm.layers import DropPath, Mlp, to_2tuple, _assert +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._manipulate import named_apply +from ._registry import register_model +__all__ = ['SwinTransformerV2Cr'] # model_registry will add each entrypoint fn to this _logger = logging.getLogger(__name__) diff --git a/timm/models/tnt.py b/timm/models/tnt.py index 5b72b196..50088baf 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -7,17 +7,18 @@ The official mindspore code is released and available at https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT """ import math + import torch import torch.nn as nn from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.models.helpers import build_model_with_cfg -from timm.models.layers import Mlp, DropPath, trunc_normal_ -from timm.models.layers.helpers import to_2tuple -from timm.models.layers import _assert -from timm.models.registry import register_model -from timm.models.vision_transformer import resize_pos_embed +from timm.layers import Mlp, DropPath, trunc_normal_, _assert, to_2tuple +from ._builder import build_model_with_cfg +from ._registry import register_model +from .vision_transformer import resize_pos_embed + +__all__ = ['TNT'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py index 2469acd2..83cb0576 100644 --- a/timm/models/tresnet.py +++ b/timm/models/tresnet.py @@ -10,11 +10,11 @@ from collections import OrderedDict import torch import torch.nn as nn -from .helpers import build_model_with_cfg -from .layers import SpaceToDepthModule, BlurPool2d, InplaceAbn, ClassifierHead, SEModule -from .registry import register_model +from timm.layers import SpaceToDepthModule, BlurPool2d, InplaceAbn, ClassifierHead, SEModule +from ._builder import build_model_with_cfg +from ._registry import register_model -__all__ = ['tresnet_m', 'tresnet_l', 'tresnet_xl'] +__all__ = ['TResNet'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/twins.py b/timm/models/twins.py index 0626db37..41944c36 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -12,20 +12,21 @@ Code/weights from https://github.com/Meituan-AutoML/Twins, original copyright/li # Written by Xinjie Li, Xiangxiang Chu # -------------------------------------------------------- import math -from copy import deepcopy -from typing import Optional, Tuple +from functools import partial +from typing import Tuple import torch import torch.nn as nn import torch.nn.functional as F -from functools import partial from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .layers import Mlp, DropPath, to_2tuple, trunc_normal_ -from .fx_features import register_notrace_module -from .registry import register_model +from timm.layers import Mlp, DropPath, to_2tuple, trunc_normal_ +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_module +from ._registry import register_model from .vision_transformer import Attention -from .helpers import build_model_with_cfg + +__all__ = ['Twins'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/vgg.py b/timm/models/vgg.py index caf96517..abe9f8d5 100644 --- a/timm/models/vgg.py +++ b/timm/models/vgg.py @@ -5,21 +5,19 @@ timm functionality. Copyright 2021 Ross Wightman """ +from typing import Union, List, Dict, Any, cast + import torch import torch.nn as nn import torch.nn.functional as F -from typing import Union, List, Dict, Any, cast from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, checkpoint_seq -from .fx_features import register_notrace_module -from .layers import ClassifierHead -from .registry import register_model - -__all__ = [ - 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', - 'vgg19_bn', 'vgg19', -] +from timm.layers import ClassifierHead +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_module +from ._registry import register_model + +__all__ = ['VGG'] def _cfg(url='', **kwargs): diff --git a/timm/models/visformer.py b/timm/models/visformer.py index 254a0748..e15ae4a5 100644 --- a/timm/models/visformer.py +++ b/timm/models/visformer.py @@ -6,17 +6,15 @@ From original at https://github.com/danczs/Visformer Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman """ -from copy import deepcopy import torch import torch.nn as nn -import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d, create_classifier -from .registry import register_model - +from timm.layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d, create_classifier +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model __all__ = ['Visformer'] diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 4effbed6..3c2ebc29 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -19,10 +19,10 @@ for some einops/einsum fun Hacked together by / Copyright 2020, Ross Wightman """ -import math import logging -from functools import partial +import math from collections import OrderedDict +from functools import partial from typing import Optional import torch @@ -30,12 +30,17 @@ import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD,\ +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \ OPENAI_CLIP_MEAN, OPENAI_CLIP_STD -from .helpers import build_model_with_cfg, named_apply, adapt_input_conv, checkpoint_seq -from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ -from .pretrained import generate_default_cfgs -from .registry import register_model +from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ +from ._builder import build_model_with_cfg +from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv +from ._pretrained import generate_default_cfgs +from ._registry import register_model + + +__all__ = ['VisionTransformer'] # model_registry will add each entrypoint fn to this + _logger = logging.getLogger(__name__) diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index 5e5113d7..cfdd0a0e 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -13,19 +13,18 @@ They were moved here to keep file sizes sane. Hacked together by / Copyright 2020, Ross Wightman """ -from copy import deepcopy from functools import partial import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .layers import StdConv2dSame, StdConv2d, to_2tuple -from .pretrained import generate_default_cfgs +from timm.layers import StdConv2dSame, StdConv2d, to_2tuple +from ._pretrained import generate_default_cfgs +from ._registry import register_model from .resnet import resnet26d, resnet50d from .resnetv2 import ResNetV2, create_resnetv2_stem -from .registry import register_model -from timm.models.vision_transformer import _create_vision_transformer +from .vision_transformer import _create_vision_transformer def _cfg(url='', **kwargs): diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py index 52b3ce45..1a7c2f40 100644 --- a/timm/models/vision_transformer_relpos.py +++ b/timm/models/vision_transformer_relpos.py @@ -4,11 +4,9 @@ NOTE: these models are experimental / WIP, expect changes Hacked together by / Copyright 2022, Ross Wightman """ -import math import logging +import math from functools import partial -from collections import OrderedDict -from dataclasses import dataclass from typing import Optional, Tuple import torch @@ -16,10 +14,12 @@ import torch.nn as nn import torch.nn.functional as F from torch.utils.checkpoint import checkpoint -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply -from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, to_2tuple -from .registry import register_model +from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_ +from ._builder import build_model_with_cfg +from ._registry import register_model + +__all__ = ['VisionTransformerRelPos'] # model_registry will add each entrypoint fn to this _logger = logging.getLogger(__name__) diff --git a/timm/models/volo.py b/timm/models/volo.py index 735453c8..1117995a 100644 --- a/timm/models/volo.py +++ b/timm/models/volo.py @@ -20,17 +20,19 @@ Modifications and additions for timm by / Copyright 2022, Ross Wightman # See the License for the specific language governing permissions and # limitations under the License. import math -import numpy as np +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.models.layers import DropPath, Mlp, to_2tuple, to_ntuple, trunc_normal_ -from timm.models.registry import register_model -from timm.models.helpers import build_model_with_cfg +from timm.layers import DropPath, Mlp, to_2tuple, to_ntuple, trunc_normal_ +from ._builder import build_model_with_cfg +from ._registry import register_model + +__all__ = ['VOLO'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/vovnet.py b/timm/models/vovnet.py index 39d37195..bf0e4f89 100644 --- a/timm/models/vovnet.py +++ b/timm/models/vovnet.py @@ -15,13 +15,15 @@ from typing import List import torch import torch.nn as nn -import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .registry import register_model -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import ConvNormAct, SeparableConvNormAct, BatchNormAct2d, ClassifierHead, DropPath,\ +from timm.layers import ConvNormAct, SeparableConvNormAct, BatchNormAct2d, ClassifierHead, DropPath, \ create_attn, create_norm_act_layer, get_norm_act_layer +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model + +__all__ = ['VovNet'] # model_registry will add each entrypoint fn to this # model cfgs adapted from https://github.com/youngwanLEE/vovnet-detectron2 & diff --git a/timm/models/xception.py b/timm/models/xception.py index 99d02c46..99e74b46 100644 --- a/timm/models/xception.py +++ b/timm/models/xception.py @@ -25,9 +25,9 @@ import torch.jit import torch.nn as nn import torch.nn.functional as F -from .helpers import build_model_with_cfg -from .layers import create_classifier -from .registry import register_model +from timm.layers import create_classifier +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['Xception'] diff --git a/timm/models/xception_aligned.py b/timm/models/xception_aligned.py index 6bbce5e6..e3348e64 100644 --- a/timm/models/xception_aligned.py +++ b/timm/models/xception_aligned.py @@ -11,10 +11,11 @@ import torch import torch.nn as nn from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import ClassifierHead, ConvNormAct, create_conv2d, get_norm_act_layer -from .layers.helpers import to_3tuple -from .registry import register_model +from timm.layers import ClassifierHead, ConvNormAct, create_conv2d, get_norm_act_layer +from timm.layers.helpers import to_3tuple +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model __all__ = ['XceptionAligned'] diff --git a/timm/models/xcit.py b/timm/models/xcit.py index 6802fc84..57c11183 100644 --- a/timm/models/xcit.py +++ b/timm/models/xcit.py @@ -19,12 +19,14 @@ import torch.nn as nn from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .vision_transformer import _cfg, Mlp -from .registry import register_model -from .layers import DropPath, trunc_normal_, to_2tuple +from timm.layers import DropPath, trunc_normal_, to_2tuple +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_module +from ._registry import register_model from .cait import ClassAttn -from .fx_features import register_notrace_module +from .vision_transformer import Mlp + +__all__ = ['XCiT'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py index 02f0e250..8613a62c 100644 --- a/timm/optim/optim_factory.py +++ b/timm/optim/optim_factory.py @@ -9,7 +9,7 @@ import torch import torch.nn as nn import torch.optim as optim -from timm.models.helpers import group_parameters +from timm.models import group_parameters from .adabelief import AdaBelief from .adafactor import Adafactor diff --git a/timm/version.py b/timm/version.py index 0f19999f..0716d38a 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.8.0dev0' +__version__ = '0.8.1dev0' diff --git a/train.py b/train.py index d40ff04b..1276840d 100755 --- a/train.py +++ b/train.py @@ -31,10 +31,9 @@ from torch.nn.parallel import DistributedDataParallel as NativeDDP from timm import utils from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset -from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, \ - LabelSmoothingCrossEntropy -from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, \ - convert_splitbn_model, convert_sync_batchnorm, model_parameters, set_fast_norm +from timm.layers import convert_splitbn_model, convert_sync_batchnorm, set_fast_norm +from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, LabelSmoothingCrossEntropy +from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, model_parameters from timm.optim import create_optimizer_v2, optimizer_kwargs from timm.scheduler import create_scheduler_v2, scheduler_kwargs from timm.utils import ApexScaler, NativeScaler @@ -82,7 +81,7 @@ parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') # Dataset parameters group = parser.add_argument_group('Dataset parameters') -# Keep this argument outside of the dataset group because it is positional. +# Keep this argument outside the dataset group because it is positional. parser.add_argument('data', nargs='?', metavar='DIR', const=None, help='path to dataset (positional is *deprecated*, use --data-dir)') parser.add_argument('--data-dir', metavar='DIR', diff --git a/validate.py b/validate.py index 6b8222b9..3bbf07cf 100755 --- a/validate.py +++ b/validate.py @@ -8,22 +8,24 @@ canonical PyTorch, standard Python style, and good performance. Repurpose as you Hacked together by Ross Wightman (https://github.com/rwightman) """ import argparse -import os import csv import glob import json -import time import logging -import torch -import torch.nn as nn -import torch.nn.parallel +import os +import time from collections import OrderedDict from contextlib import suppress from functools import partial -from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models, set_fast_norm +import torch +import torch.nn as nn +import torch.nn.parallel + from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet -from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser,\ +from timm.layers import apply_test_time_pool, set_fast_norm +from timm.models import create_model, load_checkpoint, is_model, list_models +from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser, \ decay_batch_step, check_batch_size_retry try: From cda39b35bd7ac4a8053f422802ba65f88dbb6e3c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 9 Dec 2022 14:39:45 -0800 Subject: [PATCH 2/2] Add a deprecation phase to module re-org --- benchmark.py | 3 ++- timm/models/_builder.py | 4 ++++ timm/models/_factory.py | 3 +++ timm/models/_features.py | 3 +++ timm/models/_features_fx.py | 4 ++++ timm/models/_helpers.py | 2 ++ timm/models/_hub.py | 3 +++ timm/models/_manipulate.py | 3 +++ timm/models/_pretrained.py | 3 +++ timm/models/_prune.py | 2 ++ timm/models/_registry.py | 2 +- timm/models/factory.py | 4 ++++ timm/models/features.py | 4 ++++ timm/models/fx_features.py | 4 ++++ timm/models/helpers.py | 7 +++++++ timm/models/hub.py | 4 ++++ timm/models/layers/__init__.py | 3 +++ timm/models/registry.py | 4 ++++ 18 files changed, 60 insertions(+), 2 deletions(-) create mode 100644 timm/models/factory.py create mode 100644 timm/models/features.py create mode 100644 timm/models/fx_features.py create mode 100644 timm/models/helpers.py create mode 100644 timm/models/hub.py create mode 100644 timm/models/registry.py diff --git a/benchmark.py b/benchmark.py index 04557a7d..95e2cb5a 100755 --- a/benchmark.py +++ b/benchmark.py @@ -19,7 +19,8 @@ import torch.nn as nn import torch.nn.parallel from timm.data import resolve_data_config -from timm.models import create_model, is_model, list_models, set_fast_norm +from timm.layers import set_fast_norm +from timm.models import create_model, is_model, list_models from timm.optim import create_optimizer_v2 from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry diff --git a/timm/models/_builder.py b/timm/models/_builder.py index c99c85f6..f634650e 100644 --- a/timm/models/_builder.py +++ b/timm/models/_builder.py @@ -23,6 +23,10 @@ _DOWNLOAD_PROGRESS = False _CHECK_HASH = False +__all__ = ['set_pretrained_download_progress', 'set_pretrained_check_hash', 'load_custom_pretrained', 'load_pretrained', + 'pretrained_cfg_for_features', 'resolve_pretrained_cfg', 'build_model_with_cfg'] + + def _resolve_pretrained_source(pretrained_cfg): cfg_source = pretrained_cfg.get('source', '') pretrained_url = pretrained_cfg.get('url', None) diff --git a/timm/models/_factory.py b/timm/models/_factory.py index 2b050ad6..a8092419 100644 --- a/timm/models/_factory.py +++ b/timm/models/_factory.py @@ -9,6 +9,9 @@ from ._hub import load_model_config_from_hf from ._registry import is_model, model_entrypoint +__all__ = ['parse_model_name', 'safe_model_name', 'create_model'] + + def parse_model_name(model_name): if model_name.startswith('hf_hub'): # NOTE for backwards compat, deprecate hf_hub use diff --git a/timm/models/_features.py b/timm/models/_features.py index 0bc46419..59b080cd 100644 --- a/timm/models/_features.py +++ b/timm/models/_features.py @@ -17,6 +17,9 @@ import torch import torch.nn as nn +__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet'] + + class FeatureInfo: def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]): diff --git a/timm/models/_features_fx.py b/timm/models/_features_fx.py index 2d4a33c2..10670a1d 100644 --- a/timm/models/_features_fx.py +++ b/timm/models/_features_fx.py @@ -35,6 +35,10 @@ except ImportError: pass +__all__ = ['register_notrace_module', 'register_notrace_function', 'create_feature_extractor', + 'FeatureGraphNet', 'GraphExtractNet'] + + def register_notrace_module(module: Type[nn.Module]): """ Any module not under timm.models.layers should get this decorator if we don't want to trace through it. diff --git a/timm/models/_helpers.py b/timm/models/_helpers.py index 2856842d..995292aa 100644 --- a/timm/models/_helpers.py +++ b/timm/models/_helpers.py @@ -12,6 +12,8 @@ import timm.models._builder _logger = logging.getLogger(__name__) +__all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_checkpoint', 'resume_checkpoint'] + def clean_state_dict(state_dict): # 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training diff --git a/timm/models/_hub.py b/timm/models/_hub.py index 2a87ae7e..e6b7d558 100644 --- a/timm/models/_hub.py +++ b/timm/models/_hub.py @@ -31,6 +31,9 @@ except ImportError: _logger = logging.getLogger(__name__) +__all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf', + 'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub'] + def get_cache_dir(child_dir=''): """ diff --git a/timm/models/_manipulate.py b/timm/models/_manipulate.py index 82a922a2..192979fc 100644 --- a/timm/models/_manipulate.py +++ b/timm/models/_manipulate.py @@ -9,6 +9,9 @@ import torch from torch import nn as nn from torch.utils.checkpoint import checkpoint +__all__ = ['model_parameters', 'named_apply', 'named_modules', 'named_modules_with_params', 'adapt_input_conv', + 'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq'] + def model_parameters(model, exclude_head=False): if exclude_head: diff --git a/timm/models/_pretrained.py b/timm/models/_pretrained.py index 60f38fd4..c422dab7 100644 --- a/timm/models/_pretrained.py +++ b/timm/models/_pretrained.py @@ -4,6 +4,9 @@ from dataclasses import dataclass, field, replace, asdict from typing import Any, Deque, Dict, Tuple, Optional, Union +__all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg', 'split_model_name_tag', 'generate_default_cfgs'] + + @dataclass class PretrainedCfg: """ diff --git a/timm/models/_prune.py b/timm/models/_prune.py index 0d744e40..4e744dec 100644 --- a/timm/models/_prune.py +++ b/timm/models/_prune.py @@ -5,6 +5,8 @@ from torch import nn as nn from timm.layers import Conv2dSame, BatchNormAct2d, Linear +__all__ = ['extract_layer', 'set_layer', 'adapt_model_from_string', 'adapt_model_from_file'] + def extract_layer(model, layer): layer = layer.split('.') diff --git a/timm/models/_registry.py b/timm/models/_registry.py index 97c8fd59..fc7b3437 100644 --- a/timm/models/_registry.py +++ b/timm/models/_registry.py @@ -12,7 +12,7 @@ from typing import List, Optional, Union, Tuple from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag __all__ = [ - 'list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules', + 'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules', 'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name'] _module_to_models = defaultdict(set) # dict of sets to check membership of model in module diff --git a/timm/models/factory.py b/timm/models/factory.py new file mode 100644 index 00000000..0ae83dc0 --- /dev/null +++ b/timm/models/factory.py @@ -0,0 +1,4 @@ +from ._factory import * + +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning) diff --git a/timm/models/features.py b/timm/models/features.py new file mode 100644 index 00000000..25605d99 --- /dev/null +++ b/timm/models/features.py @@ -0,0 +1,4 @@ +from ._features import * + +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning) diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py new file mode 100644 index 00000000..0ff3a18b --- /dev/null +++ b/timm/models/fx_features.py @@ -0,0 +1,4 @@ +from ._features_fx import * + +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning) diff --git a/timm/models/helpers.py b/timm/models/helpers.py new file mode 100644 index 00000000..6bc82eb8 --- /dev/null +++ b/timm/models/helpers.py @@ -0,0 +1,7 @@ +from ._builder import * +from ._helpers import * +from ._manipulate import * +from ._prune import * + +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning) diff --git a/timm/models/hub.py b/timm/models/hub.py new file mode 100644 index 00000000..074ca025 --- /dev/null +++ b/timm/models/hub.py @@ -0,0 +1,4 @@ +from _hub import * + +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning) diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 1bfb95d5..97e70563 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -43,3 +43,6 @@ from timm.layers.std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, Scal from timm.layers.test_time_pool import TestTimePoolHead, apply_test_time_pool from timm.layers.trace_utils import _assert, _float_to_int from timm.layers.weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_ + +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning) diff --git a/timm/models/registry.py b/timm/models/registry.py new file mode 100644 index 00000000..58e2e1f4 --- /dev/null +++ b/timm/models/registry.py @@ -0,0 +1,4 @@ +from ._registry import * + +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)