From 927f031293a30afb940fff0bee34b85d9c059b0e Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 6 Dec 2022 15:00:06 -0800 Subject: [PATCH 1/9] Major module / path restructure, timm.models.layers -> timm.layers, add _ prefix to all non model modules in timm.models --- avg_checkpoints.py | 2 +- clean_checkpoint.py | 2 +- hubconf.py | 5 +- inference.py | 9 +- tests/test_layers.py | 5 +- tests/test_models.py | 2 +- timm/__init__.py | 2 +- timm/data/readers/class_map.py | 3 +- timm/layers/__init__.py | 44 + timm/{models => }/layers/activations.py | 0 timm/{models => }/layers/activations_jit.py | 0 timm/{models => }/layers/activations_me.py | 0 .../layers/adaptive_avgmax_pool.py | 0 timm/{models => }/layers/attention_pool2d.py | 0 timm/{models => }/layers/blur_pool.py | 0 timm/{models => }/layers/bottleneck_attn.py | 0 timm/{models => }/layers/cbam.py | 0 timm/{models => }/layers/classifier.py | 0 timm/{models => }/layers/cond_conv2d.py | 0 timm/{models => }/layers/config.py | 0 timm/{models => }/layers/conv2d_same.py | 0 timm/{models => }/layers/conv_bn_act.py | 0 timm/{models => }/layers/create_act.py | 0 timm/{models => }/layers/create_attn.py | 0 timm/{models => }/layers/create_conv2d.py | 0 timm/{models => }/layers/create_norm.py | 0 timm/{models => }/layers/create_norm_act.py | 0 timm/{models => }/layers/drop.py | 0 timm/{models => }/layers/eca.py | 0 timm/{models => }/layers/evo_norm.py | 0 timm/{models => }/layers/fast_norm.py | 0 .../layers/filter_response_norm.py | 0 timm/{models => }/layers/gather_excite.py | 0 timm/{models => }/layers/global_context.py | 0 timm/{models => }/layers/halo_attn.py | 0 timm/{models => }/layers/helpers.py | 0 timm/{models => }/layers/inplace_abn.py | 0 timm/{models => }/layers/lambda_layer.py | 0 timm/{models => }/layers/linear.py | 0 timm/{models => }/layers/median_pool.py | 0 timm/{models => }/layers/mixed_conv2d.py | 0 timm/{models => }/layers/ml_decoder.py | 0 timm/{models => }/layers/mlp.py | 0 timm/{models => }/layers/non_local_attn.py | 0 timm/{models => }/layers/norm.py | 0 timm/{models => }/layers/norm_act.py | 0 timm/{models => }/layers/padding.py | 0 timm/{models => }/layers/patch_embed.py | 0 timm/{models => }/layers/pool2d_same.py | 0 timm/{models => }/layers/pos_embed.py | 0 timm/{models => }/layers/selective_kernel.py | 0 timm/{models => }/layers/separable_conv.py | 0 timm/{models => }/layers/space_to_depth.py | 0 timm/{models => }/layers/split_attn.py | 0 timm/{models => }/layers/split_batchnorm.py | 0 timm/{models => }/layers/squeeze_excite.py | 0 timm/{models => }/layers/std_conv.py | 0 timm/{models => }/layers/test_time_pool.py | 0 timm/{models => }/layers/trace_utils.py | 0 timm/{models => }/layers/weight_init.py | 0 timm/models/__init__.py | 22 +- timm/models/_builder.py | 395 ++++++++ ...tnet_blocks.py => _efficientnet_blocks.py} | 3 +- ...et_builder.py => _efficientnet_builder.py} | 4 +- timm/models/{factory.py => _factory.py} | 10 +- timm/models/{features.py => _features.py} | 0 .../{fx_features.py => _features_fx.py} | 10 +- timm/models/_helpers.py | 113 +++ timm/models/{hub.py => _hub.py} | 2 +- timm/models/_manipulate.py | 255 ++++++ timm/models/{pretrained.py => _pretrained.py} | 0 timm/models/_prune.py | 111 +++ .../ecaresnet101d_pruned.txt | 0 .../ecaresnet50d_pruned.txt | 0 .../efficientnet_b1_pruned.txt | 0 .../efficientnet_b2_pruned.txt | 0 .../efficientnet_b3_pruned.txt | 0 timm/models/{registry.py => _registry.py} | 6 +- timm/models/beit.py | 9 +- timm/models/byoanet.py | 4 +- timm/models/byobnet.py | 12 +- 
timm/models/cait.py | 9 +- timm/models/coat.py | 19 +- timm/models/convit.py | 16 +- timm/models/convmixer.py | 9 +- timm/models/convnext.py | 10 +- timm/models/crossvit.py | 19 +- timm/models/cspnet.py | 14 +- timm/models/deit.py | 6 +- timm/models/densenet.py | 8 +- timm/models/dla.py | 6 +- timm/models/dpn.py | 6 +- timm/models/edgenext.py | 14 +- timm/models/efficientformer.py | 8 +- timm/models/efficientnet.py | 14 +- timm/models/gcvit.py | 11 +- timm/models/ghostnet.py | 11 +- timm/models/gluon_resnet.py | 8 +- timm/models/gluon_xception.py | 6 +- timm/models/hardcorenas.py | 12 +- timm/models/helpers.py | 855 ------------------ timm/models/hrnet.py | 10 +- timm/models/inception_resnet_v2.py | 7 +- timm/models/inception_v3.py | 10 +- timm/models/inception_v4.py | 6 +- timm/models/layers/__init__.py | 83 +- timm/models/levit.py | 12 +- timm/models/maxxvit.py | 18 +- timm/models/mlp_mixer.py | 10 +- timm/models/mobilenetv3.py | 13 +- timm/models/mobilevit.py | 12 +- timm/models/mvitv2.py | 10 +- timm/models/nasnet.py | 6 +- timm/models/nest.py | 14 +- timm/models/nfnet.py | 16 +- timm/models/pit.py | 10 +- timm/models/pnasnet.py | 6 +- timm/models/poolformer.py | 10 +- timm/models/pvt_v2.py | 6 +- timm/models/regnet.py | 11 +- timm/models/res2net.py | 4 +- timm/models/resnest.py | 7 +- timm/models/resnet.py | 15 +- timm/models/resnetv2.py | 11 +- timm/models/rexnet.py | 16 +- timm/models/selecsls.py | 6 +- timm/models/senet.py | 6 +- timm/models/sequencer.py | 10 +- timm/models/sknet.py | 6 +- timm/models/swin_transformer.py | 11 +- timm/models/swin_transformer_v2.py | 10 +- timm/models/swin_transformer_v2_cr.py | 11 +- timm/models/tnt.py | 13 +- timm/models/tresnet.py | 8 +- timm/models/twins.py | 15 +- timm/models/vgg.py | 18 +- timm/models/visformer.py | 10 +- timm/models/vision_transformer.py | 19 +- timm/models/vision_transformer_hybrid.py | 9 +- timm/models/vision_transformer_relpos.py | 14 +- timm/models/volo.py | 10 +- timm/models/vovnet.py | 10 +- timm/models/xception.py | 6 +- timm/models/xception_aligned.py | 9 +- timm/models/xcit.py | 12 +- timm/optim/optim_factory.py | 2 +- timm/version.py | 2 +- train.py | 9 +- validate.py | 16 +- 149 files changed, 1387 insertions(+), 1269 deletions(-) create mode 100644 timm/layers/__init__.py rename timm/{models => }/layers/activations.py (100%) rename timm/{models => }/layers/activations_jit.py (100%) rename timm/{models => }/layers/activations_me.py (100%) rename timm/{models => }/layers/adaptive_avgmax_pool.py (100%) rename timm/{models => }/layers/attention_pool2d.py (100%) rename timm/{models => }/layers/blur_pool.py (100%) rename timm/{models => }/layers/bottleneck_attn.py (100%) rename timm/{models => }/layers/cbam.py (100%) rename timm/{models => }/layers/classifier.py (100%) rename timm/{models => }/layers/cond_conv2d.py (100%) rename timm/{models => }/layers/config.py (100%) rename timm/{models => }/layers/conv2d_same.py (100%) rename timm/{models => }/layers/conv_bn_act.py (100%) rename timm/{models => }/layers/create_act.py (100%) rename timm/{models => }/layers/create_attn.py (100%) rename timm/{models => }/layers/create_conv2d.py (100%) rename timm/{models => }/layers/create_norm.py (100%) rename timm/{models => }/layers/create_norm_act.py (100%) rename timm/{models => }/layers/drop.py (100%) rename timm/{models => }/layers/eca.py (100%) rename timm/{models => }/layers/evo_norm.py (100%) rename timm/{models => }/layers/fast_norm.py (100%) rename timm/{models => }/layers/filter_response_norm.py (100%) rename timm/{models 
=> }/layers/gather_excite.py (100%) rename timm/{models => }/layers/global_context.py (100%) rename timm/{models => }/layers/halo_attn.py (100%) rename timm/{models => }/layers/helpers.py (100%) rename timm/{models => }/layers/inplace_abn.py (100%) rename timm/{models => }/layers/lambda_layer.py (100%) rename timm/{models => }/layers/linear.py (100%) rename timm/{models => }/layers/median_pool.py (100%) rename timm/{models => }/layers/mixed_conv2d.py (100%) rename timm/{models => }/layers/ml_decoder.py (100%) rename timm/{models => }/layers/mlp.py (100%) rename timm/{models => }/layers/non_local_attn.py (100%) rename timm/{models => }/layers/norm.py (100%) rename timm/{models => }/layers/norm_act.py (100%) rename timm/{models => }/layers/padding.py (100%) rename timm/{models => }/layers/patch_embed.py (100%) rename timm/{models => }/layers/pool2d_same.py (100%) rename timm/{models => }/layers/pos_embed.py (100%) rename timm/{models => }/layers/selective_kernel.py (100%) rename timm/{models => }/layers/separable_conv.py (100%) rename timm/{models => }/layers/space_to_depth.py (100%) rename timm/{models => }/layers/split_attn.py (100%) rename timm/{models => }/layers/split_batchnorm.py (100%) rename timm/{models => }/layers/squeeze_excite.py (100%) rename timm/{models => }/layers/std_conv.py (100%) rename timm/{models => }/layers/test_time_pool.py (100%) rename timm/{models => }/layers/trace_utils.py (100%) rename timm/{models => }/layers/weight_init.py (100%) create mode 100644 timm/models/_builder.py rename timm/models/{efficientnet_blocks.py => _efficientnet_blocks.py} (99%) rename timm/models/{efficientnet_builder.py => _efficientnet_builder.py} (99%) rename timm/models/{factory.py => _factory.py} (94%) rename timm/models/{features.py => _features.py} (100%) rename timm/models/{fx_features.py => _features_fx.py} (93%) create mode 100644 timm/models/_helpers.py rename timm/models/{hub.py => _hub.py} (99%) create mode 100644 timm/models/_manipulate.py rename timm/models/{pretrained.py => _pretrained.py} (100%) create mode 100644 timm/models/_prune.py rename timm/models/{pruned => _pruned}/ecaresnet101d_pruned.txt (100%) rename timm/models/{pruned => _pruned}/ecaresnet50d_pruned.txt (100%) rename timm/models/{pruned => _pruned}/efficientnet_b1_pruned.txt (100%) rename timm/models/{pruned => _pruned}/efficientnet_b2_pruned.txt (100%) rename timm/models/{pruned => _pruned}/efficientnet_b3_pruned.txt (100%) rename timm/models/{registry.py => _registry.py} (95%) delete mode 100644 timm/models/helpers.py diff --git a/avg_checkpoints.py b/avg_checkpoints.py index ea8bbe84..83af5bbd 100755 --- a/avg_checkpoints.py +++ b/avg_checkpoints.py @@ -16,7 +16,7 @@ import argparse import os import glob import hashlib -from timm.models.helpers import load_state_dict +from timm.models import load_state_dict parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager') parser.add_argument('--input', default='', type=str, metavar='PATH', diff --git a/clean_checkpoint.py b/clean_checkpoint.py index 8ec892b2..17c270db 100755 --- a/clean_checkpoint.py +++ b/clean_checkpoint.py @@ -13,7 +13,7 @@ import os import hashlib import shutil from collections import OrderedDict -from timm.models.helpers import load_state_dict +from timm.models import load_state_dict parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner') parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', diff --git a/hubconf.py b/hubconf.py index 70fed79a..6b2061ea 100644 --- a/hubconf.py +++ 
b/hubconf.py @@ -1,4 +1,3 @@ dependencies = ['torch'] -from timm.models import registry - -globals().update(registry._model_entrypoints) +import timm +globals().update(timm.models._registry._model_entrypoints) diff --git a/inference.py b/inference.py index bc794840..1509b323 100755 --- a/inference.py +++ b/inference.py @@ -5,11 +5,11 @@ An example inference script that outputs top-k class ids for images in a folder Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman) """ -import os -import time import argparse import json import logging +import os +import time from contextlib import suppress from functools import partial @@ -17,12 +17,11 @@ import numpy as np import pandas as pd import torch -from timm.models import create_model, apply_test_time_pool, load_checkpoint from timm.data import create_dataset, create_loader, resolve_data_config +from timm.layers import apply_test_time_pool +from timm.models import create_model from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser - - try: from apex import amp has_apex = True diff --git a/tests/test_layers.py b/tests/test_layers.py index 508a6aae..da061870 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -1,10 +1,7 @@ -import pytest import torch import torch.nn as nn -import platform -import os -from timm.models.layers import create_act_layer, get_act_layer, set_layer_config +from timm.layers import create_act_layer, set_layer_config class MLP(nn.Module): diff --git a/tests/test_models.py b/tests/test_models.py index dd1330eb..d6c0052f 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -14,7 +14,7 @@ except ImportError: import timm from timm import list_models, create_model, set_scriptable, get_pretrained_cfg_value -from timm.models.fx_features import _leaf_modules, _autowrap_functions +from timm.models._features_fx import _leaf_modules, _autowrap_functions if hasattr(torch._C, '_jit_set_profiling_executor'): # legacy executor is too slow to compile large models for unit tests diff --git a/timm/__init__.py b/timm/__init__.py index faf34dbc..3d38cdb9 100644 --- a/timm/__init__.py +++ b/timm/__init__.py @@ -1,4 +1,4 @@ from .version import __version__ +from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable from .models import create_model, list_models, list_pretrained, is_model, list_modules, model_entrypoint, \ - is_scriptable, is_exportable, set_scriptable, set_exportable, \ is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value diff --git a/timm/data/readers/class_map.py b/timm/data/readers/class_map.py index 6cf3f57e..885be6e2 100644 --- a/timm/data/readers/class_map.py +++ b/timm/data/readers/class_map.py @@ -1,6 +1,7 @@ import os import pickle + def load_class_map(map_or_filename, root=''): if isinstance(map_or_filename, dict): assert dict, 'class_map dict must be non-empty' @@ -14,7 +15,7 @@ def load_class_map(map_or_filename, root=''): with open(class_map_path) as f: class_to_idx = {v.strip(): k for k, v in enumerate(f)} elif class_map_ext == '.pkl': - with open(class_map_path,'rb') as f: + with open(class_map_path, 'rb') as f: class_to_idx = pickle.load(f) else: assert False, f'Unsupported class map file extension ({class_map_ext}).' 
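For orientation, a minimal sketch (illustrative only, not part of the patch) of how downstream code adapts to this restructure; the model name and the imported symbols are arbitrary examples:

    import timm
    from timm.layers import DropPath, trunc_normal_         # was: from timm.models.layers import ...
    from timm.models import create_model, load_state_dict   # helpers now re-exported from timm.models._helpers

    model = create_model('resnet50', pretrained=False)       # any registered model name works here
    print(type(model).__name__, sum(p.numel() for p in model.parameters()))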
diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py new file mode 100644 index 00000000..21c641b6 --- /dev/null +++ b/timm/layers/__init__.py @@ -0,0 +1,44 @@ +from .activations import * +from .adaptive_avgmax_pool import \ + adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d +from .blur_pool import BlurPool2d +from .classifier import ClassifierHead, create_classifier +from .cond_conv2d import CondConv2d, get_condconv_initializer +from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ + set_layer_config +from .conv2d_same import Conv2dSame, conv2d_same +from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct +from .create_act import create_act_layer, get_act_layer, get_act_fn +from .create_attn import get_attn, create_attn +from .create_conv2d import create_conv2d +from .create_norm import get_norm_layer, create_norm_layer +from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer +from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path +from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn +from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\ + EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a +from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm +from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d +from .gather_excite import GatherExcite +from .global_context import GlobalContext +from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple +from .inplace_abn import InplaceAbn +from .linear import Linear +from .mixed_conv2d import MixedConv2d +from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp +from .non_local_attn import NonLocalAttn, BatNonLocalAttn +from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d +from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm +from .padding import get_padding, get_same_padding, pad_same +from .patch_embed import PatchEmbed +from .pool2d_same import AvgPool2dSame, create_pool2d +from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite +from .selective_kernel import SelectiveKernel +from .separable_conv import SeparableConv2d, SeparableConvNormAct +from .space_to_depth import SpaceToDepthModule +from .split_attn import SplitAttn +from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model +from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame +from .test_time_pool import TestTimePoolHead, apply_test_time_pool +from .trace_utils import _assert, _float_to_int +from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_ diff --git a/timm/models/layers/activations.py b/timm/layers/activations.py similarity index 100% rename from timm/models/layers/activations.py rename to timm/layers/activations.py diff --git a/timm/models/layers/activations_jit.py b/timm/layers/activations_jit.py similarity index 100% rename from timm/models/layers/activations_jit.py rename to timm/layers/activations_jit.py diff --git a/timm/models/layers/activations_me.py b/timm/layers/activations_me.py similarity index 100% rename from timm/models/layers/activations_me.py rename to timm/layers/activations_me.py diff --git a/timm/models/layers/adaptive_avgmax_pool.py b/timm/layers/adaptive_avgmax_pool.py similarity index 100% rename from 
timm/models/layers/adaptive_avgmax_pool.py rename to timm/layers/adaptive_avgmax_pool.py diff --git a/timm/models/layers/attention_pool2d.py b/timm/layers/attention_pool2d.py similarity index 100% rename from timm/models/layers/attention_pool2d.py rename to timm/layers/attention_pool2d.py diff --git a/timm/models/layers/blur_pool.py b/timm/layers/blur_pool.py similarity index 100% rename from timm/models/layers/blur_pool.py rename to timm/layers/blur_pool.py diff --git a/timm/models/layers/bottleneck_attn.py b/timm/layers/bottleneck_attn.py similarity index 100% rename from timm/models/layers/bottleneck_attn.py rename to timm/layers/bottleneck_attn.py diff --git a/timm/models/layers/cbam.py b/timm/layers/cbam.py similarity index 100% rename from timm/models/layers/cbam.py rename to timm/layers/cbam.py diff --git a/timm/models/layers/classifier.py b/timm/layers/classifier.py similarity index 100% rename from timm/models/layers/classifier.py rename to timm/layers/classifier.py diff --git a/timm/models/layers/cond_conv2d.py b/timm/layers/cond_conv2d.py similarity index 100% rename from timm/models/layers/cond_conv2d.py rename to timm/layers/cond_conv2d.py diff --git a/timm/models/layers/config.py b/timm/layers/config.py similarity index 100% rename from timm/models/layers/config.py rename to timm/layers/config.py diff --git a/timm/models/layers/conv2d_same.py b/timm/layers/conv2d_same.py similarity index 100% rename from timm/models/layers/conv2d_same.py rename to timm/layers/conv2d_same.py diff --git a/timm/models/layers/conv_bn_act.py b/timm/layers/conv_bn_act.py similarity index 100% rename from timm/models/layers/conv_bn_act.py rename to timm/layers/conv_bn_act.py diff --git a/timm/models/layers/create_act.py b/timm/layers/create_act.py similarity index 100% rename from timm/models/layers/create_act.py rename to timm/layers/create_act.py diff --git a/timm/models/layers/create_attn.py b/timm/layers/create_attn.py similarity index 100% rename from timm/models/layers/create_attn.py rename to timm/layers/create_attn.py diff --git a/timm/models/layers/create_conv2d.py b/timm/layers/create_conv2d.py similarity index 100% rename from timm/models/layers/create_conv2d.py rename to timm/layers/create_conv2d.py diff --git a/timm/models/layers/create_norm.py b/timm/layers/create_norm.py similarity index 100% rename from timm/models/layers/create_norm.py rename to timm/layers/create_norm.py diff --git a/timm/models/layers/create_norm_act.py b/timm/layers/create_norm_act.py similarity index 100% rename from timm/models/layers/create_norm_act.py rename to timm/layers/create_norm_act.py diff --git a/timm/models/layers/drop.py b/timm/layers/drop.py similarity index 100% rename from timm/models/layers/drop.py rename to timm/layers/drop.py diff --git a/timm/models/layers/eca.py b/timm/layers/eca.py similarity index 100% rename from timm/models/layers/eca.py rename to timm/layers/eca.py diff --git a/timm/models/layers/evo_norm.py b/timm/layers/evo_norm.py similarity index 100% rename from timm/models/layers/evo_norm.py rename to timm/layers/evo_norm.py diff --git a/timm/models/layers/fast_norm.py b/timm/layers/fast_norm.py similarity index 100% rename from timm/models/layers/fast_norm.py rename to timm/layers/fast_norm.py diff --git a/timm/models/layers/filter_response_norm.py b/timm/layers/filter_response_norm.py similarity index 100% rename from timm/models/layers/filter_response_norm.py rename to timm/layers/filter_response_norm.py diff --git a/timm/models/layers/gather_excite.py 
b/timm/layers/gather_excite.py similarity index 100% rename from timm/models/layers/gather_excite.py rename to timm/layers/gather_excite.py diff --git a/timm/models/layers/global_context.py b/timm/layers/global_context.py similarity index 100% rename from timm/models/layers/global_context.py rename to timm/layers/global_context.py diff --git a/timm/models/layers/halo_attn.py b/timm/layers/halo_attn.py similarity index 100% rename from timm/models/layers/halo_attn.py rename to timm/layers/halo_attn.py diff --git a/timm/models/layers/helpers.py b/timm/layers/helpers.py similarity index 100% rename from timm/models/layers/helpers.py rename to timm/layers/helpers.py diff --git a/timm/models/layers/inplace_abn.py b/timm/layers/inplace_abn.py similarity index 100% rename from timm/models/layers/inplace_abn.py rename to timm/layers/inplace_abn.py diff --git a/timm/models/layers/lambda_layer.py b/timm/layers/lambda_layer.py similarity index 100% rename from timm/models/layers/lambda_layer.py rename to timm/layers/lambda_layer.py diff --git a/timm/models/layers/linear.py b/timm/layers/linear.py similarity index 100% rename from timm/models/layers/linear.py rename to timm/layers/linear.py diff --git a/timm/models/layers/median_pool.py b/timm/layers/median_pool.py similarity index 100% rename from timm/models/layers/median_pool.py rename to timm/layers/median_pool.py diff --git a/timm/models/layers/mixed_conv2d.py b/timm/layers/mixed_conv2d.py similarity index 100% rename from timm/models/layers/mixed_conv2d.py rename to timm/layers/mixed_conv2d.py diff --git a/timm/models/layers/ml_decoder.py b/timm/layers/ml_decoder.py similarity index 100% rename from timm/models/layers/ml_decoder.py rename to timm/layers/ml_decoder.py diff --git a/timm/models/layers/mlp.py b/timm/layers/mlp.py similarity index 100% rename from timm/models/layers/mlp.py rename to timm/layers/mlp.py diff --git a/timm/models/layers/non_local_attn.py b/timm/layers/non_local_attn.py similarity index 100% rename from timm/models/layers/non_local_attn.py rename to timm/layers/non_local_attn.py diff --git a/timm/models/layers/norm.py b/timm/layers/norm.py similarity index 100% rename from timm/models/layers/norm.py rename to timm/layers/norm.py diff --git a/timm/models/layers/norm_act.py b/timm/layers/norm_act.py similarity index 100% rename from timm/models/layers/norm_act.py rename to timm/layers/norm_act.py diff --git a/timm/models/layers/padding.py b/timm/layers/padding.py similarity index 100% rename from timm/models/layers/padding.py rename to timm/layers/padding.py diff --git a/timm/models/layers/patch_embed.py b/timm/layers/patch_embed.py similarity index 100% rename from timm/models/layers/patch_embed.py rename to timm/layers/patch_embed.py diff --git a/timm/models/layers/pool2d_same.py b/timm/layers/pool2d_same.py similarity index 100% rename from timm/models/layers/pool2d_same.py rename to timm/layers/pool2d_same.py diff --git a/timm/models/layers/pos_embed.py b/timm/layers/pos_embed.py similarity index 100% rename from timm/models/layers/pos_embed.py rename to timm/layers/pos_embed.py diff --git a/timm/models/layers/selective_kernel.py b/timm/layers/selective_kernel.py similarity index 100% rename from timm/models/layers/selective_kernel.py rename to timm/layers/selective_kernel.py diff --git a/timm/models/layers/separable_conv.py b/timm/layers/separable_conv.py similarity index 100% rename from timm/models/layers/separable_conv.py rename to timm/layers/separable_conv.py diff --git a/timm/models/layers/space_to_depth.py 
b/timm/layers/space_to_depth.py similarity index 100% rename from timm/models/layers/space_to_depth.py rename to timm/layers/space_to_depth.py diff --git a/timm/models/layers/split_attn.py b/timm/layers/split_attn.py similarity index 100% rename from timm/models/layers/split_attn.py rename to timm/layers/split_attn.py diff --git a/timm/models/layers/split_batchnorm.py b/timm/layers/split_batchnorm.py similarity index 100% rename from timm/models/layers/split_batchnorm.py rename to timm/layers/split_batchnorm.py diff --git a/timm/models/layers/squeeze_excite.py b/timm/layers/squeeze_excite.py similarity index 100% rename from timm/models/layers/squeeze_excite.py rename to timm/layers/squeeze_excite.py diff --git a/timm/models/layers/std_conv.py b/timm/layers/std_conv.py similarity index 100% rename from timm/models/layers/std_conv.py rename to timm/layers/std_conv.py diff --git a/timm/models/layers/test_time_pool.py b/timm/layers/test_time_pool.py similarity index 100% rename from timm/models/layers/test_time_pool.py rename to timm/layers/test_time_pool.py diff --git a/timm/models/layers/trace_utils.py b/timm/layers/trace_utils.py similarity index 100% rename from timm/models/layers/trace_utils.py rename to timm/layers/trace_utils.py diff --git a/timm/models/layers/weight_init.py b/timm/layers/weight_init.py similarity index 100% rename from timm/models/layers/weight_init.py rename to timm/layers/weight_init.py diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 301186dd..5ecc8915 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -64,12 +64,18 @@ from .xception import * from .xception_aligned import * from .xcit import * -from .factory import create_model, parse_model_name, safe_model_name -from .helpers import load_checkpoint, resume_checkpoint, model_parameters -from .layers import TestTimePoolHead, apply_test_time_pool -from .layers import convert_splitbn_model, convert_sync_batchnorm -from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit -from .layers import set_fast_norm -from .pretrained import PretrainedCfg, filter_pretrained_cfg, generate_default_cfgs, split_model_name_tag -from .registry import register_model, model_entrypoint, list_models, list_pretrained, is_model, list_modules,\ +from ._builder import build_model_with_cfg, load_pretrained, load_custom_pretrained, resolve_pretrained_cfg, \ + set_pretrained_download_progress, set_pretrained_check_hash +from ._factory import create_model, parse_model_name, safe_model_name +from ._features import FeatureInfo, FeatureHooks, FeatureHookNet, FeatureListNet, FeatureDictNet +from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, \ + register_notrace_module, register_notrace_function +from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_checkpoint, resume_checkpoint +from ._hub import load_model_config_from_hf, load_state_dict_from_hf, push_to_hf_hub +from ._manipulate import model_parameters, named_apply, named_modules, named_modules_with_params, \ + group_modules, group_parameters, checkpoint_seq, adapt_input_conv +from ._pretrained import PretrainedCfg, DefaultCfg, \ + filter_pretrained_cfg, generate_default_cfgs, split_model_name_tag +from ._prune import adapt_model_from_string +from ._registry import register_model, model_entrypoint, list_models, list_pretrained, is_model, list_modules, \ is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value diff --git 
a/timm/models/_builder.py b/timm/models/_builder.py new file mode 100644 index 00000000..c99c85f6 --- /dev/null +++ b/timm/models/_builder.py @@ -0,0 +1,395 @@ +import dataclasses +import logging +from copy import deepcopy +from typing import Optional, Dict, Callable, Any, Tuple + +from torch import nn as nn +from torch.hub import load_state_dict_from_url + +from timm.models._features import FeatureListNet, FeatureHookNet +from timm.models._features_fx import FeatureGraphNet +from timm.models._helpers import load_state_dict +from timm.models._hub import has_hf_hub, download_cached_file, load_state_dict_from_hf +from timm.models._manipulate import adapt_input_conv +from timm.models._pretrained import PretrainedCfg +from timm.models._prune import adapt_model_from_file +from timm.models._registry import get_pretrained_cfg + +_logger = logging.getLogger(__name__) + +# Global variables for rarely used pretrained checkpoint download progress and hash check. +# Use set_pretrained_download_progress / set_pretrained_check_hash functions to toggle. +_DOWNLOAD_PROGRESS = False +_CHECK_HASH = False + + +def _resolve_pretrained_source(pretrained_cfg): + cfg_source = pretrained_cfg.get('source', '') + pretrained_url = pretrained_cfg.get('url', None) + pretrained_file = pretrained_cfg.get('file', None) + hf_hub_id = pretrained_cfg.get('hf_hub_id', None) + # resolve where to load pretrained weights from + load_from = '' + pretrained_loc = '' + if cfg_source == 'hf-hub' and has_hf_hub(necessary=True): + # hf-hub specified as source via model identifier + load_from = 'hf-hub' + assert hf_hub_id + pretrained_loc = hf_hub_id + else: + # default source == timm or unspecified + if pretrained_file: + load_from = 'file' + pretrained_loc = pretrained_file + elif pretrained_url: + load_from = 'url' + pretrained_loc = pretrained_url + elif hf_hub_id and has_hf_hub(necessary=True): + # hf-hub available as alternate weight source in default_cfg + load_from = 'hf-hub' + pretrained_loc = hf_hub_id + if load_from == 'hf-hub' and pretrained_cfg.get('hf_hub_filename', None): + # if a filename override is set, return tuple for location w/ (hub_id, filename) + pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename'] + return load_from, pretrained_loc + + +def set_pretrained_download_progress(enable=True): + """ Set download progress for pretrained weights on/off (globally). """ + global _DOWNLOAD_PROGRESS + _DOWNLOAD_PROGRESS = enable + + +def set_pretrained_check_hash(enable=True): + """ Set hash checking for pretrained weights on/off (globally). """ + global _CHECK_HASH + _CHECK_HASH = enable + + +def load_custom_pretrained( + model: nn.Module, + pretrained_cfg: Optional[Dict] = None, + load_fn: Optional[Callable] = None, +): + r"""Loads a custom (read non .pth) weight file + + Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls + a passed in custom load fun, or the `load_pretrained` model member fn. + + If the object is already present in `model_dir`, it's deserialized and returned. + The default value of `model_dir` is ``/checkpoints`` where + `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`. 
+ + Args: + model: The instantiated model to load weights into + pretrained_cfg (dict): Default pretrained model cfg + load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named + 'load_pretrained' on the model will be called if it exists + """ + pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None) + if not pretrained_cfg: + _logger.warning("Invalid pretrained config, cannot load weights.") + return + + load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg) + if not load_from: + _logger.warning("No pretrained weights exist for this model. Using random initialization.") + return + if load_from == 'hf-hub': # FIXME + _logger.warning("Hugging Face hub not currently supported for custom load pretrained models.") + elif load_from == 'url': + pretrained_loc = download_cached_file( + pretrained_loc, + check_hash=_CHECK_HASH, + progress=_DOWNLOAD_PROGRESS + ) + + if load_fn is not None: + load_fn(model, pretrained_loc) + elif hasattr(model, 'load_pretrained'): + model.load_pretrained(pretrained_loc) + else: + _logger.warning("Valid function to load pretrained weights is not available, using random initialization.") + + +def load_pretrained( + model: nn.Module, + pretrained_cfg: Optional[Dict] = None, + num_classes: int = 1000, + in_chans: int = 3, + filter_fn: Optional[Callable] = None, + strict: bool = True, +): + """ Load pretrained checkpoint + + Args: + model (nn.Module): PyTorch model module + pretrained_cfg (Optional[Dict]): configuration for pretrained weights / target dataset + num_classes (int): num_classes for target model + in_chans (int): in_chans for target model + filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args) + strict (bool): strict load of checkpoint + + """ + pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None) + if not pretrained_cfg: + _logger.warning("Invalid pretrained config, cannot load weights.") + return + + load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg) + if load_from == 'file': + _logger.info(f'Loading pretrained weights from file ({pretrained_loc})') + state_dict = load_state_dict(pretrained_loc) + elif load_from == 'url': + _logger.info(f'Loading pretrained weights from url ({pretrained_loc})') + state_dict = load_state_dict_from_url( + pretrained_loc, + map_location='cpu', + progress=_DOWNLOAD_PROGRESS, + check_hash=_CHECK_HASH, + ) + elif load_from == 'hf-hub': + _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})') + if isinstance(pretrained_loc, (list, tuple)): + state_dict = load_state_dict_from_hf(*pretrained_loc) + else: + state_dict = load_state_dict_from_hf(pretrained_loc) + else: + _logger.warning("No pretrained weights exist or were found for this model.
Using random initialization.") + return + + if filter_fn is not None: + # for backwards compat with filter fn that take one arg, try one first, the two + try: + state_dict = filter_fn(state_dict) + except TypeError: + state_dict = filter_fn(state_dict, model) + + input_convs = pretrained_cfg.get('first_conv', None) + if input_convs is not None and in_chans != 3: + if isinstance(input_convs, str): + input_convs = (input_convs,) + for input_conv_name in input_convs: + weight_name = input_conv_name + '.weight' + try: + state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name]) + _logger.info( + f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)') + except NotImplementedError as e: + del state_dict[weight_name] + strict = False + _logger.warning( + f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.') + + classifiers = pretrained_cfg.get('classifier', None) + label_offset = pretrained_cfg.get('label_offset', 0) + if classifiers is not None: + if isinstance(classifiers, str): + classifiers = (classifiers,) + if num_classes != pretrained_cfg['num_classes']: + for classifier_name in classifiers: + # completely discard fully connected if model num_classes doesn't match pretrained weights + state_dict.pop(classifier_name + '.weight', None) + state_dict.pop(classifier_name + '.bias', None) + strict = False + elif label_offset > 0: + for classifier_name in classifiers: + # special case for pretrained weights with an extra background class in pretrained weights + classifier_weight = state_dict[classifier_name + '.weight'] + state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:] + classifier_bias = state_dict[classifier_name + '.bias'] + state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:] + + model.load_state_dict(state_dict, strict=strict) + + +def pretrained_cfg_for_features(pretrained_cfg): + pretrained_cfg = deepcopy(pretrained_cfg) + # remove default pretrained cfg fields that don't have much relevance for feature backbone + to_remove = ('num_classes', 'classifier', 'global_pool') # add default final pool size? 
+ for tr in to_remove: + pretrained_cfg.pop(tr, None) + return pretrained_cfg + + +def _filter_kwargs(kwargs, names): + if not kwargs or not names: + return + for n in names: + kwargs.pop(n, None) + + +def _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter): + """ Update the default_cfg and kwargs before passing to model + + Args: + pretrained_cfg: input pretrained cfg (updated in-place) + kwargs: keyword args passed to model build fn (updated in-place) + kwargs_filter: keyword arg keys that must be removed before model __init__ + """ + # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs) + default_kwarg_names = ('num_classes', 'global_pool', 'in_chans') + if pretrained_cfg.get('fixed_input_size', False): + # if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size + default_kwarg_names += ('img_size',) + + for n in default_kwarg_names: + # for legacy reasons, model __init__args uses img_size + in_chans as separate args while + # pretrained_cfg has one input_size=(C, H ,W) entry + if n == 'img_size': + input_size = pretrained_cfg.get('input_size', None) + if input_size is not None: + assert len(input_size) == 3 + kwargs.setdefault(n, input_size[-2:]) + elif n == 'in_chans': + input_size = pretrained_cfg.get('input_size', None) + if input_size is not None: + assert len(input_size) == 3 + kwargs.setdefault(n, input_size[0]) + else: + default_val = pretrained_cfg.get(n, None) + if default_val is not None: + kwargs.setdefault(n, pretrained_cfg[n]) + + # Filter keyword args for task specific model variants (some 'features only' models, etc.) + _filter_kwargs(kwargs, names=kwargs_filter) + + +def resolve_pretrained_cfg( + variant: str, + pretrained_cfg=None, + pretrained_cfg_overlay=None, +) -> PretrainedCfg: + model_with_tag = variant + pretrained_tag = None + if pretrained_cfg: + if isinstance(pretrained_cfg, dict): + # pretrained_cfg dict passed as arg, validate by converting to PretrainedCfg + pretrained_cfg = PretrainedCfg(**pretrained_cfg) + elif isinstance(pretrained_cfg, str): + pretrained_tag = pretrained_cfg + pretrained_cfg = None + + # fallback to looking up pretrained cfg in model registry by variant identifier + if not pretrained_cfg: + if pretrained_tag: + model_with_tag = '.'.join([variant, pretrained_tag]) + pretrained_cfg = get_pretrained_cfg(model_with_tag) + + if not pretrained_cfg: + _logger.warning( + f"No pretrained configuration specified for {model_with_tag} model. Using a default." 
+ f" Please add a config to the model pretrained_cfg registry or pass explicitly.") + pretrained_cfg = PretrainedCfg() # instance with defaults + + pretrained_cfg_overlay = pretrained_cfg_overlay or {} + if not pretrained_cfg.architecture: + pretrained_cfg_overlay.setdefault('architecture', variant) + pretrained_cfg = dataclasses.replace(pretrained_cfg, **pretrained_cfg_overlay) + + return pretrained_cfg + + +def build_model_with_cfg( + model_cls: Callable, + variant: str, + pretrained: bool, + pretrained_cfg: Optional[Dict] = None, + pretrained_cfg_overlay: Optional[Dict] = None, + model_cfg: Optional[Any] = None, + feature_cfg: Optional[Dict] = None, + pretrained_strict: bool = True, + pretrained_filter_fn: Optional[Callable] = None, + kwargs_filter: Optional[Tuple[str]] = None, + **kwargs, +): + """ Build model with specified default_cfg and optional model_cfg + + This helper fn aids in the construction of a model including: + * handling default_cfg and associated pretrained weight loading + * passing through optional model_cfg for models with config based arch spec + * features_only model adaptation + * pruning config / model adaptation + + Args: + model_cls (nn.Module): model class + variant (str): model variant name + pretrained (bool): load pretrained weights + pretrained_cfg (dict): model's pretrained weight/task config + model_cfg (Optional[Dict]): model's architecture config + feature_cfg (Optional[Dict]: feature extraction adapter config + pretrained_strict (bool): load pretrained weights strictly + pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights + kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model + **kwargs: model args passed through to model __init__ + """ + pruned = kwargs.pop('pruned', False) + features = False + feature_cfg = feature_cfg or {} + + # resolve and update model pretrained config and model kwargs + pretrained_cfg = resolve_pretrained_cfg( + variant, + pretrained_cfg=pretrained_cfg, + pretrained_cfg_overlay=pretrained_cfg_overlay + ) + + # FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model + pretrained_cfg = pretrained_cfg.to_dict() + + _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter) + + # Setup for feature extraction wrapper done at end of this fn + if kwargs.pop('features_only', False): + features = True + feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4)) + if 'out_indices' in kwargs: + feature_cfg['out_indices'] = kwargs.pop('out_indices') + + # Instantiate the model + if model_cfg is None: + model = model_cls(**kwargs) + else: + model = model_cls(cfg=model_cfg, **kwargs) + model.pretrained_cfg = pretrained_cfg + model.default_cfg = model.pretrained_cfg # alias for backwards compat + + if pruned: + model = adapt_model_from_file(model, variant) + + # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats + num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000)) + if pretrained: + if pretrained_cfg.get('custom_load', False): + load_custom_pretrained( + model, + pretrained_cfg=pretrained_cfg, + ) + else: + load_pretrained( + model, + pretrained_cfg=pretrained_cfg, + num_classes=num_classes_pretrained, + in_chans=kwargs.get('in_chans', 3), + filter_fn=pretrained_filter_fn, + strict=pretrained_strict, + ) + + # Wrap the model in a feature extraction module if enabled + if features: + feature_cls = FeatureListNet + if 'feature_cls' in feature_cfg: + 
feature_cls = feature_cfg.pop('feature_cls') + if isinstance(feature_cls, str): + feature_cls = feature_cls.lower() + if 'hook' in feature_cls: + feature_cls = FeatureHookNet + elif feature_cls == 'fx': + feature_cls = FeatureGraphNet + else: + assert False, f'Unknown feature class {feature_cls}' + model = feature_cls(model, **feature_cfg) + model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg) # add back default_cfg + model.default_cfg = model.pretrained_cfg # alias for backwards compat + + return model diff --git a/timm/models/efficientnet_blocks.py b/timm/models/_efficientnet_blocks.py similarity index 99% rename from timm/models/efficientnet_blocks.py rename to timm/models/_efficientnet_blocks.py index 34a31757..92b849e4 100644 --- a/timm/models/efficientnet_blocks.py +++ b/timm/models/_efficientnet_blocks.py @@ -2,13 +2,12 @@ Hacked together by / Copyright 2019, Ross Wightman """ -import math import torch import torch.nn as nn from torch.nn import functional as F -from .layers import create_conv2d, DropPath, make_divisible, create_act_layer, get_norm_act_layer +from timm.layers import create_conv2d, DropPath, make_divisible, create_act_layer, get_norm_act_layer __all__ = [ 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual'] diff --git a/timm/models/efficientnet_builder.py b/timm/models/_efficientnet_builder.py similarity index 99% rename from timm/models/efficientnet_builder.py rename to timm/models/_efficientnet_builder.py index 67d15a86..e6cd05ae 100644 --- a/timm/models/efficientnet_builder.py +++ b/timm/models/_efficientnet_builder.py @@ -14,8 +14,8 @@ from functools import partial import torch.nn as nn -from .efficientnet_blocks import * -from .layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible +from ._efficientnet_blocks import * +from timm.layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible __all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights", 'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT'] diff --git a/timm/models/factory.py b/timm/models/_factory.py similarity index 94% rename from timm/models/factory.py rename to timm/models/_factory.py index 9e06c1aa..2b050ad6 100644 --- a/timm/models/factory.py +++ b/timm/models/_factory.py @@ -2,11 +2,11 @@ import os from typing import Any, Dict, Optional, Union from urllib.parse import urlsplit -from .pretrained import PretrainedCfg, split_model_name_tag -from .helpers import load_checkpoint -from .hub import load_model_config_from_hf -from .layers import set_layer_config -from .registry import is_model, model_entrypoint +from timm.layers import set_layer_config +from ._pretrained import PretrainedCfg, split_model_name_tag +from ._helpers import load_checkpoint +from ._hub import load_model_config_from_hf +from ._registry import is_model, model_entrypoint def parse_model_name(model_name): diff --git a/timm/models/features.py b/timm/models/_features.py similarity index 100% rename from timm/models/features.py rename to timm/models/_features.py diff --git a/timm/models/fx_features.py b/timm/models/_features_fx.py similarity index 93% rename from timm/models/fx_features.py rename to timm/models/_features_fx.py index b09381b7..2d4a33c2 100644 --- a/timm/models/fx_features.py +++ b/timm/models/_features_fx.py @@ -6,7 +6,7 @@ from typing import Callable, List, Dict, Union, Type import torch from torch import 
nn -from .features import _get_feature_info +from ._features import _get_feature_info try: from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor @@ -15,9 +15,9 @@ except ImportError: has_fx_feature_extraction = False # Layers we want to treat as leaf modules -from .layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame -from .layers.non_local_attn import BilinearAttnTransform -from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame +from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame +from timm.layers.non_local_attn import BilinearAttnTransform +from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame # NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here # BUT modules from timm.models should use the registration mechanism below @@ -29,7 +29,7 @@ _leaf_modules = { } try: - from .layers import InplaceAbn + from timm.layers import InplaceAbn _leaf_modules.add(InplaceAbn) except ImportError: pass diff --git a/timm/models/_helpers.py b/timm/models/_helpers.py new file mode 100644 index 00000000..2856842d --- /dev/null +++ b/timm/models/_helpers.py @@ -0,0 +1,113 @@ +""" Model creation / weight loading / state_dict helpers + +Hacked together by / Copyright 2020 Ross Wightman +""" +import logging +import os +from collections import OrderedDict + +import torch + +import timm.models._builder + +_logger = logging.getLogger(__name__) + + +def clean_state_dict(state_dict): + # 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training + cleaned_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k[7:] if k.startswith('module.') else k + cleaned_state_dict[name] = v + return cleaned_state_dict + + +def load_state_dict(checkpoint_path, use_ema=True): + if checkpoint_path and os.path.isfile(checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location='cpu') + state_dict_key = '' + if isinstance(checkpoint, dict): + if use_ema and checkpoint.get('state_dict_ema', None) is not None: + state_dict_key = 'state_dict_ema' + elif use_ema and checkpoint.get('model_ema', None) is not None: + state_dict_key = 'model_ema' + elif 'state_dict' in checkpoint: + state_dict_key = 'state_dict' + elif 'model' in checkpoint: + state_dict_key = 'model' + state_dict = clean_state_dict(checkpoint[state_dict_key] if state_dict_key else checkpoint) + _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path)) + return state_dict + else: + _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) + raise FileNotFoundError() + + +def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True, remap=False): + if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'): + # numpy checkpoint, try to load via model specific load_pretrained fn + if hasattr(model, 'load_pretrained'): + model.load_pretrained(checkpoint_path) + else: + raise NotImplementedError('Model cannot load numpy checkpoint') + return + state_dict = load_state_dict(checkpoint_path, use_ema) + if remap: + state_dict = remap_checkpoint(model, state_dict) + incompatible_keys = model.load_state_dict(state_dict, strict=strict) + return incompatible_keys + + +def remap_checkpoint(model, state_dict, allow_reshape=True): + """ remap checkpoint by iterating over state dicts in order (ignoring original keys).
+ This assumes models (and originating state dict) were created with params registered in same order. + """ + out_dict = {} + for (ka, va), (kb, vb) in zip(model.state_dict().items(), state_dict.items()): + assert va.numel() == vb.numel(), f'Tensor size mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.' + if va.shape != vb.shape: + if allow_reshape: + vb = vb.reshape(va.shape) + else: + assert False, f'Tensor shape mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.' + out_dict[ka] = vb + return out_dict + + +def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True): + resume_epoch = None + if os.path.isfile(checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location='cpu') + if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + if log_info: + _logger.info('Restoring model state from checkpoint...') + state_dict = clean_state_dict(checkpoint['state_dict']) + model.load_state_dict(state_dict) + + if optimizer is not None and 'optimizer' in checkpoint: + if log_info: + _logger.info('Restoring optimizer state from checkpoint...') + optimizer.load_state_dict(checkpoint['optimizer']) + + if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint: + if log_info: + _logger.info('Restoring AMP loss scaler state from checkpoint...') + loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key]) + + if 'epoch' in checkpoint: + resume_epoch = checkpoint['epoch'] + if 'version' in checkpoint and checkpoint['version'] > 1: + resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save + + if log_info: + _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) + else: + model.load_state_dict(checkpoint) + if log_info: + _logger.info("Loaded checkpoint '{}'".format(checkpoint_path)) + return resume_epoch + else: + _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) + raise FileNotFoundError() + + diff --git a/timm/models/hub.py b/timm/models/_hub.py similarity index 99% rename from timm/models/hub.py rename to timm/models/_hub.py index 18c5444a..2a87ae7e 100644 --- a/timm/models/hub.py +++ b/timm/models/_hub.py @@ -15,7 +15,7 @@ except ImportError: from torch.hub import _get_torch_home as get_dir from timm import __version__ -from timm.models.pretrained import filter_pretrained_cfg +from timm.models._pretrained import filter_pretrained_cfg try: from huggingface_hub import ( diff --git a/timm/models/_manipulate.py b/timm/models/_manipulate.py new file mode 100644 index 00000000..82a922a2 --- /dev/null +++ b/timm/models/_manipulate.py @@ -0,0 +1,255 @@ +import collections.abc +import math +import re +from collections import defaultdict +from itertools import chain +from typing import Callable, Union, Dict + +import torch +from torch import nn as nn +from torch.utils.checkpoint import checkpoint + + +def model_parameters(model, exclude_head=False): + if exclude_head: + # FIXME this is a bit of a quick and dirty hack to skip classifier head params based on ordering + return [p for p in model.parameters()][:-2] + else: + return model.parameters() + + +def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = '.'.join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first,
include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +def named_modules(module: nn.Module, name='', depth_first=True, include_root=False): + if not depth_first and include_root: + yield name, module + for child_name, child_module in module.named_children(): + child_name = '.'.join((name, child_name)) if name else child_name + yield from named_modules( + module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + yield name, module + + +def named_modules_with_params(module: nn.Module, name='', depth_first=True, include_root=False): + if module._parameters and not depth_first and include_root: + yield name, module + for child_name, child_module in module.named_children(): + child_name = '.'.join((name, child_name)) if name else child_name + yield from named_modules_with_params( + module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if module._parameters and depth_first and include_root: + yield name, module + + +MATCH_PREV_GROUP = (99999,) + + +def group_with_matcher( + named_objects, + group_matcher: Union[Dict, Callable], + output_values: bool = False, + reverse: bool = False +): + if isinstance(group_matcher, dict): + # dictionary matcher contains a dict of raw-string regex expr that must be compiled + compiled = [] + for group_ordinal, (group_name, mspec) in enumerate(group_matcher.items()): + if mspec is None: + continue + # map all matching specifications into 3-tuple (compiled re, prefix, suffix) + if isinstance(mspec, (tuple, list)): + # multi-entry match specifications require each sub-spec to be a 2-tuple (re, suffix) + for sspec in mspec: + compiled += [(re.compile(sspec[0]), (group_ordinal,), sspec[1])] + else: + compiled += [(re.compile(mspec), (group_ordinal,), None)] + group_matcher = compiled + + def _get_grouping(name): + if isinstance(group_matcher, (list, tuple)): + for match_fn, prefix, suffix in group_matcher: + r = match_fn.match(name) + if r: + parts = (prefix, r.groups(), suffix) + # map all tuple elem to int for numeric sort, filter out None entries + return tuple(map(float, chain.from_iterable(filter(None, parts)))) + return float('inf'), # un-matched layers (neck, head) mapped to largest ordinal + else: + ord = group_matcher(name) + if not isinstance(ord, collections.abc.Iterable): + return ord, + return tuple(ord) + + # map layers into groups via ordinals (ints or tuples of ints) from matcher + grouping = defaultdict(list) + for k, v in named_objects: + grouping[_get_grouping(k)].append(v if output_values else k) + + # remap to integers + layer_id_to_param = defaultdict(list) + lid = -1 + for k in sorted(filter(lambda x: x is not None, grouping.keys())): + if lid < 0 or k[-1] != MATCH_PREV_GROUP[0]: + lid += 1 + layer_id_to_param[lid].extend(grouping[k]) + + if reverse: + assert not output_values, "reverse mapping only sensible for name output" + # output reverse mapping + param_to_layer_id = {} + for lid, lm in layer_id_to_param.items(): + for n in lm: + param_to_layer_id[n] = lid + return param_to_layer_id + + return layer_id_to_param + + +def group_parameters( + module: nn.Module, + group_matcher, + output_values=False, + reverse=False, +): + return group_with_matcher( + module.named_parameters(), group_matcher, output_values=output_values, reverse=reverse) + + +def group_modules( + module: nn.Module, + group_matcher, + output_values=False, + reverse=False, +): + return group_with_matcher( + 
named_modules_with_params(module), group_matcher, output_values=output_values, reverse=reverse) + + +def flatten_modules(named_modules, depth=1, prefix='', module_types='sequential'): + prefix_is_tuple = isinstance(prefix, tuple) + if isinstance(module_types, str): + if module_types == 'container': + module_types = (nn.Sequential, nn.ModuleList, nn.ModuleDict) + else: + module_types = (nn.Sequential,) + for name, module in named_modules: + if depth and isinstance(module, module_types): + yield from flatten_modules( + module.named_children(), + depth - 1, + prefix=(name,) if prefix_is_tuple else name, + module_types=module_types, + ) + else: + if prefix_is_tuple: + name = prefix + (name,) + yield name, module + else: + if prefix: + name = '.'.join([prefix, name]) + yield name, module + + +def checkpoint_seq( + functions, + x, + every=1, + flatten=False, + skip_last=False, + preserve_rng_state=True +): + r"""A helper function for checkpointing sequential models. + + Sequential models execute a list of modules/functions in order + (sequentially). Therefore, we can divide such a sequence into segments + and checkpoint each segment. All checkpointed segments run in :func:`torch.no_grad` + manner, i.e., not storing the intermediate activations. The inputs of each + checkpointed segment will be saved for re-running the segment in the backward pass. + + See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works. + + .. warning:: + Checkpointing currently only supports :func:`torch.autograd.backward` + and only if its `inputs` argument is not passed. :func:`torch.autograd.grad` + is not supported. + + .. warning:: + At least one of the inputs needs to have :code:`requires_grad=True` if + grads are needed for model inputs, otherwise the checkpointed part of the + model won't have gradients. + + Args: + functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially. + x: A Tensor that is input to :attr:`functions` + every: checkpoint every-n functions (default: 1) + flatten (bool): flatten nn.Sequential of nn.Sequentials + skip_last (bool): skip checkpointing the last function in the sequence if True + preserve_rng_state (bool, optional, default=True): Omit stashing and restoring + the RNG state during each checkpoint. + + Returns: + Output of running :attr:`functions` sequentially on :attr:`x` + + Example: + >>> model = nn.Sequential(...)
+ >>> input_var = checkpoint_seq(model, input_var, every=2) + """ + def run_function(start, end, functions): + def forward(_x): + for j in range(start, end + 1): + _x = functions[j](_x) + return _x + return forward + + if isinstance(functions, torch.nn.Sequential): + functions = functions.children() + if flatten: + functions = chain.from_iterable(functions) + if not isinstance(functions, (tuple, list)): + functions = tuple(functions) + + num_checkpointed = len(functions) + if skip_last: + num_checkpointed -= 1 + end = -1 + for start in range(0, num_checkpointed, every): + end = min(start + every - 1, num_checkpointed - 1) + x = checkpoint(run_function(start, end, functions), x, preserve_rng_state=preserve_rng_state) + if skip_last: + return run_function(end + 1, len(functions) - 1, functions)(x) + return x + + +def adapt_input_conv(in_chans, conv_weight): + conv_type = conv_weight.dtype + conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU + O, I, J, K = conv_weight.shape + if in_chans == 1: + if I > 3: + assert conv_weight.shape[1] % 3 == 0 + # For models with space2depth stems + conv_weight = conv_weight.reshape(O, I // 3, 3, J, K) + conv_weight = conv_weight.sum(dim=2, keepdim=False) + else: + conv_weight = conv_weight.sum(dim=1, keepdim=True) + elif in_chans != 3: + if I != 3: + raise NotImplementedError('Weight format not supported by conversion.') + else: + # NOTE this strategy should be better than random init, but there could be other combinations of + # the original RGB input layer weights that'd work better for specific cases. + repeat = int(math.ceil(in_chans / 3)) + conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] + conv_weight *= (3 / float(in_chans)) + conv_weight = conv_weight.to(conv_type) + return conv_weight diff --git a/timm/models/pretrained.py b/timm/models/_pretrained.py similarity index 100% rename from timm/models/pretrained.py rename to timm/models/_pretrained.py diff --git a/timm/models/_prune.py b/timm/models/_prune.py new file mode 100644 index 00000000..0d744e40 --- /dev/null +++ b/timm/models/_prune.py @@ -0,0 +1,111 @@ +import os +from copy import deepcopy + +from torch import nn as nn + +from timm.layers import Conv2dSame, BatchNormAct2d, Linear + + +def extract_layer(model, layer): + layer = layer.split('.') + module = model + if hasattr(model, 'module') and layer[0] != 'module': + module = model.module + if not hasattr(model, 'module') and layer[0] == 'module': + layer = layer[1:] + for l in layer: + if hasattr(module, l): + if not l.isdigit(): + module = getattr(module, l) + else: + module = module[int(l)] + else: + return module + return module + + +def set_layer(model, layer, val): + layer = layer.split('.') + module = model + if hasattr(model, 'module') and layer[0] != 'module': + module = model.module + lst_index = 0 + module2 = module + for l in layer: + if hasattr(module2, l): + if not l.isdigit(): + module2 = getattr(module2, l) + else: + module2 = module2[int(l)] + lst_index += 1 + lst_index -= 1 + for l in layer[:lst_index]: + if not l.isdigit(): + module = getattr(module, l) + else: + module = module[int(l)] + l = layer[lst_index] + setattr(module, l, val) + + +def adapt_model_from_string(parent_module, model_string): + separator = '***' + state_dict = {} + lst_shape = model_string.split(separator) + for k in lst_shape: + k = k.split(':') + key = k[0] + shape = k[1][1:-1].split(',') + if shape[0] != '': + state_dict[key] = [int(i) for i in shape] + + new_module = 
deepcopy(parent_module) + for n, m in parent_module.named_modules(): + old_module = extract_layer(parent_module, n) + if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame): + if isinstance(old_module, Conv2dSame): + conv = Conv2dSame + else: + conv = nn.Conv2d + s = state_dict[n + '.weight'] + in_channels = s[1] + out_channels = s[0] + g = 1 + if old_module.groups > 1: + in_channels = out_channels + g = in_channels + new_conv = conv( + in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size, + bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation, + groups=g, stride=old_module.stride) + set_layer(new_module, n, new_conv) + elif isinstance(old_module, BatchNormAct2d): + new_bn = BatchNormAct2d( + state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum, + affine=old_module.affine, track_running_stats=True) + new_bn.drop = old_module.drop + new_bn.act = old_module.act + set_layer(new_module, n, new_bn) + elif isinstance(old_module, nn.BatchNorm2d): + new_bn = nn.BatchNorm2d( + num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum, + affine=old_module.affine, track_running_stats=True) + set_layer(new_module, n, new_bn) + elif isinstance(old_module, nn.Linear): + # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer? + num_features = state_dict[n + '.weight'][1] + new_fc = Linear( + in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None) + set_layer(new_module, n, new_fc) + if hasattr(new_module, 'num_features'): + new_module.num_features = num_features + new_module.eval() + parent_module.eval() + + return new_module + + +def adapt_model_from_file(parent_module, model_variant): + adapt_file = os.path.join(os.path.dirname(__file__), '_pruned', model_variant + '.txt') + with open(adapt_file, 'r') as f: + return adapt_model_from_string(parent_module, f.read().strip()) diff --git a/timm/models/pruned/ecaresnet101d_pruned.txt b/timm/models/_pruned/ecaresnet101d_pruned.txt similarity index 100% rename from timm/models/pruned/ecaresnet101d_pruned.txt rename to timm/models/_pruned/ecaresnet101d_pruned.txt diff --git a/timm/models/pruned/ecaresnet50d_pruned.txt b/timm/models/_pruned/ecaresnet50d_pruned.txt similarity index 100% rename from timm/models/pruned/ecaresnet50d_pruned.txt rename to timm/models/_pruned/ecaresnet50d_pruned.txt diff --git a/timm/models/pruned/efficientnet_b1_pruned.txt b/timm/models/_pruned/efficientnet_b1_pruned.txt similarity index 100% rename from timm/models/pruned/efficientnet_b1_pruned.txt rename to timm/models/_pruned/efficientnet_b1_pruned.txt diff --git a/timm/models/pruned/efficientnet_b2_pruned.txt b/timm/models/_pruned/efficientnet_b2_pruned.txt similarity index 100% rename from timm/models/pruned/efficientnet_b2_pruned.txt rename to timm/models/_pruned/efficientnet_b2_pruned.txt diff --git a/timm/models/pruned/efficientnet_b3_pruned.txt b/timm/models/_pruned/efficientnet_b3_pruned.txt similarity index 100% rename from timm/models/pruned/efficientnet_b3_pruned.txt rename to timm/models/_pruned/efficientnet_b3_pruned.txt diff --git a/timm/models/registry.py b/timm/models/_registry.py similarity index 95% rename from timm/models/registry.py rename to timm/models/_registry.py index 159ffb5f..97c8fd59 100644 --- a/timm/models/registry.py +++ b/timm/models/_registry.py @@ -9,7 +9,7 @@ from collections import defaultdict, deque from copy 
import deepcopy from typing import List, Optional, Union, Tuple -from .pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag +from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag __all__ = [ 'list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules', @@ -167,10 +167,12 @@ def is_model(model_name): return arch_name in _model_entrypoints -def model_entrypoint(model_name): +def model_entrypoint(model_name, module_filter: Optional[str] = None): """Fetch a model entrypoint for specified model name """ arch_name = get_arch_name(model_name) + if module_filter and arch_name not in _module_to_models.get(module_filter, {}): + raise RuntimeError(f'Model ({model_name} not found in module {module_filter}.') return _model_entrypoints[arch_name] diff --git a/timm/models/beit.py b/timm/models/beit.py index 1f6bf82b..7c4dd14d 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -46,12 +46,13 @@ import torch.nn as nn import torch.nn.functional as F from torch.utils.checkpoint import checkpoint -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_ -from .registry import register_model +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD +from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_ +from ._builder import build_model_with_cfg +from ._registry import register_model from .vision_transformer import checkpoint_filter_fn +__all__ = ['Beit'] def _cfg(url='', **kwargs): return { diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index 3815fa30..c67144cc 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -13,9 +13,9 @@ Consider all of the models definitions here as experimental WIP and likely to ch Hacked together by / copyright Ross Wightman, 2021. """ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from ._builder import build_model_with_cfg +from ._registry import register_model from .byobnet import ByoBlockCfg, ByoModelCfg, ByobNet, interleave_blocks -from .helpers import build_model_with_cfg -from .registry import register_model __all__ = [] diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index 1e402629..0e5c9c7f 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -26,18 +26,18 @@ Hacked together by / copyright Ross Wightman, 2021. 
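For reference, the new `module_filter` argument restricts the registry lookup to entrypoints registered under a single sub-module of timm.models. A minimal usage sketch (the architecture and sub-module names are illustrative; calling the entrypoint directly skips the pretrained-config handling that create_model performs):

from timm.models._registry import model_entrypoint

create_fn = model_entrypoint('resnet50')                          # search all registered modules
create_fn = model_entrypoint('resnet50', module_filter='resnet')  # must be registered by timm/models/resnet.py
model = create_fn(pretrained=False)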
""" import math from dataclasses import dataclass, field, replace -from typing import Tuple, List, Dict, Optional, Union, Any, Callable, Sequence from functools import partial +from typing import Tuple, List, Dict, Optional, Union, Any, Callable, Sequence import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, named_apply, checkpoint_seq -from .layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \ - create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0, EvoNorm2dS0a,\ - EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a, FilterResponseNormAct2d, FilterResponseNormTlu2d -from .registry import register_model +from timm.layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \ + create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0a +from ._builder import build_model_with_cfg +from ._manipulate import named_apply, checkpoint_seq +from ._registry import register_model __all__ = ['ByobNet', 'ByoModelCfg', 'ByoBlockCfg', 'create_byob_stem', 'create_block'] diff --git a/timm/models/cait.py b/timm/models/cait.py index c0892099..15dcd956 100644 --- a/timm/models/cait.py +++ b/timm/models/cait.py @@ -8,17 +8,16 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W """ # Copyright (c) 2015-present, Facebook, Inc. # All rights reserved. -from copy import deepcopy from functools import partial import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_ -from .registry import register_model - +from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_ +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model __all__ = ['Cait', 'ClassAttn', 'LayerScaleBlockClassAttn', 'LayerScaleBlock', 'TalkingHeadAttn'] diff --git a/timm/models/coat.py b/timm/models/coat.py index c3071a6c..4ed6d8e8 100644 --- a/timm/models/coat.py +++ b/timm/models/coat.py @@ -7,7 +7,6 @@ Official CoaT code at: https://github.com/mlpc-ucsd/CoaT Modified from timm/models/vision_transformer.py """ -from copy import deepcopy from functools import partial from typing import Tuple, List, Union @@ -16,19 +15,11 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_ -from .registry import register_model -from .layers import _assert - - -__all__ = [ - "coat_tiny", - "coat_mini", - "coat_lite_tiny", - "coat_lite_mini", - "coat_lite_small" -] +from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert +from ._builder import build_model_with_cfg +from ._registry import register_model + +__all__ = ['CoaT'] def _cfg_coat(url='', **kwargs): diff --git a/timm/models/convit.py b/timm/models/convit.py index 26849f6e..d117ccdc 100644 --- a/timm/models/convit.py +++ b/timm/models/convit.py @@ -22,20 +22,20 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py ''' +from functools import partial + import torch import torch.nn as nn 
-from functools import partial -import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import DropPath, to_2tuple, trunc_normal_, PatchEmbed, Mlp -from .registry import register_model +from timm.layers import DropPath, trunc_normal_, PatchEmbed, Mlp +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_module +from ._registry import register_model from .vision_transformer_hybrid import HybridEmbed -from .fx_features import register_notrace_module -import torch -import torch.nn as nn + +__all__ = ['ConViT'] def _cfg(url='', **kwargs): diff --git a/timm/models/convmixer.py b/timm/models/convmixer.py index e7e2481a..3a8c6cf5 100644 --- a/timm/models/convmixer.py +++ b/timm/models/convmixer.py @@ -5,9 +5,12 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.models.registry import register_model -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import SelectAdaptivePool2d +from timm.layers import SelectAdaptivePool2d +from ._registry import register_model +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq + +__all__ = ['ConvMixer'] def _cfg(url='', **kwargs): diff --git a/timm/models/convnext.py b/timm/models/convnext.py index 36a484b3..eea5782a 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -18,12 +18,12 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import named_apply, build_model_with_cfg, checkpoint_seq -from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, LayerNorm, \ +from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, LayerNorm, \ create_conv2d, get_act_layer, make_divisible, to_ntuple -from .pretrained import generate_default_cfgs -from .registry import register_model - +from ._builder import build_model_with_cfg +from ._manipulate import named_apply, checkpoint_seq +from ._pretrained import generate_default_cfgs +from ._registry import register_model __all__ = ['ConvNeXt'] # model_registry will add each entrypoint fn to this diff --git a/timm/models/crossvit.py b/timm/models/crossvit.py index 764eb3fe..908fcf6d 100644 --- a/timm/models/crossvit.py +++ b/timm/models/crossvit.py @@ -24,21 +24,22 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W Modifed from Timm. 
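convmixer and convnext above now import checkpoint_seq from the new _manipulate module; it splits a sequential stack into segments and gradient-checkpoints each one. A self-contained sketch using a toy nn.Linear stack rather than a timm model:

import torch
import torch.nn as nn
from timm.models._manipulate import checkpoint_seq  # added earlier in this patch

blocks = nn.Sequential(*[nn.Linear(64, 64) for _ in range(8)])
x = torch.randn(2, 64, requires_grad=True)
# four checkpointed segments of two layers each; activations inside a segment are recomputed on backward
out = checkpoint_seq(blocks, x, every=2)
out.sum().backward()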
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py """ +from functools import partial +from typing import List from typing import Tuple import torch -import torch.nn as nn -import torch.nn.functional as F import torch.hub -from functools import partial -from typing import List +import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_function -from .helpers import build_model_with_cfg -from .layers import DropPath, to_2tuple, trunc_normal_, _assert -from .registry import register_model -from .vision_transformer import Mlp, Block +from timm.layers import DropPath, to_2tuple, trunc_normal_, _assert +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._registry import register_model +from .vision_transformer import Block + +__all__ = ['CrossViT'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py index 2c09e7e3..280f929e 100644 --- a/timm/models/cspnet.py +++ b/timm/models/cspnet.py @@ -12,20 +12,18 @@ Reference impl via darknet cfg files at https://github.com/WongKinYiu/CrossStage Hacked together by / Copyright 2020 Ross Wightman """ -import collections.abc -from dataclasses import dataclass, field, asdict +from dataclasses import dataclass, asdict from functools import partial -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import torch import torch.nn as nn -import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, named_apply, MATCH_PREV_GROUP -from .layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, get_attn, create_act_layer, make_divisible -from .registry import register_model - +from timm.layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, get_attn, create_act_layer, make_divisible +from ._builder import build_model_with_cfg +from ._manipulate import named_apply, MATCH_PREV_GROUP +from ._registry import register_model __all__ = ['CspNet'] # model_registry will add each entrypoint fn to this diff --git a/timm/models/deit.py b/timm/models/deit.py index 3205b024..24fbbe56 100644 --- a/timm/models/deit.py +++ b/timm/models/deit.py @@ -17,9 +17,11 @@ from torch import nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.vision_transformer import VisionTransformer, trunc_normal_, checkpoint_filter_fn +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model -from .helpers import build_model_with_cfg, checkpoint_seq -from .registry import register_model +__all__ = ['VisionTransformerDistilled'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/densenet.py b/timm/models/densenet.py index 1afdfd7b..e731f7b0 100644 --- a/timm/models/densenet.py +++ b/timm/models/densenet.py @@ -4,7 +4,6 @@ fixed kwargs passthrough and addition of dynamic global avg/max pool. 
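cspnet likewise pulls named_apply from _manipulate; it walks sub-modules depth-first and passes the dotted module name to the callback, which is what name-dependent weight init relies on. A toy sketch (the module names are invented for illustration):

import torch.nn as nn
from timm.models._manipulate import named_apply

def _init_weights(module, name=''):
    # named_apply invokes the callback as fn(module=..., name=...)
    if isinstance(module, nn.Linear):
        nn.init.trunc_normal_(module.weight, std=0.02)
        nn.init.zeros_(module.bias)
        if name == 'head':
            nn.init.zeros_(module.weight)  # e.g. zero-init a classifier head

model = nn.Sequential()
model.add_module('stem', nn.Linear(8, 8))
model.add_module('head', nn.Linear(8, 2))
named_apply(_init_weights, model)  # visits 'stem', then 'head'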
""" import re from collections import OrderedDict -from functools import partial import torch import torch.nn as nn @@ -13,9 +12,10 @@ import torch.utils.checkpoint as cp from torch.jit.annotations import List from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, MATCH_PREV_GROUP -from .layers import BatchNormAct2d, create_norm_act_layer, BlurPool2d, create_classifier -from .registry import register_model +from timm.layers import BatchNormAct2d, create_norm_act_layer, BlurPool2d, create_classifier +from ._builder import build_model_with_cfg +from ._manipulate import MATCH_PREV_GROUP +from ._registry import register_model __all__ = ['DenseNet'] diff --git a/timm/models/dla.py b/timm/models/dla.py index 0ab807c0..204fcb4b 100644 --- a/timm/models/dla.py +++ b/timm/models/dla.py @@ -13,9 +13,9 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import create_classifier -from .registry import register_model +from timm.layers import create_classifier +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['DLA'] diff --git a/timm/models/dpn.py b/timm/models/dpn.py index 95159729..87bd918f 100644 --- a/timm/models/dpn.py +++ b/timm/models/dpn.py @@ -15,9 +15,9 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import BatchNormAct2d, ConvNormAct, create_conv2d, create_classifier -from .registry import register_model +from timm.layers import BatchNormAct2d, ConvNormAct, create_conv2d, create_classifier +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['DPN'] diff --git a/timm/models/edgenext.py b/timm/models/edgenext.py index 422d4f2c..d90471fb 100644 --- a/timm/models/edgenext.py +++ b/timm/models/edgenext.py @@ -8,20 +8,20 @@ Original code and weights from https://github.com/mmaaz60/EdgeNeXt Modifications and additions for timm by / Copyright 2022, Ross Wightman """ import math -import torch from collections import OrderedDict from functools import partial from typing import Tuple -from torch import nn +import torch import torch.nn.functional as F +from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_module -from .layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d -from .helpers import named_apply, build_model_with_cfg, checkpoint_seq -from .registry import register_model - +from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_module +from ._manipulate import named_apply, checkpoint_seq +from ._registry import register_model __all__ = ['EdgeNeXt'] # model_registry will add each entrypoint fn to this diff --git a/timm/models/efficientformer.py b/timm/models/efficientformer.py index 4749d93a..4f33f29a 100644 --- a/timm/models/efficientformer.py +++ b/timm/models/efficientformer.py @@ -18,9 +18,11 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import DropPath, trunc_normal_, to_2tuple, Mlp -from .registry 
import register_model +from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp +from ._builder import build_model_with_cfg +from ._registry import register_model + +__all__ = ['EfficientFormer'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 3c0efc96..a1324ae3 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -42,15 +42,15 @@ import torch import torch.nn as nn import torch.nn.functional as F - from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .efficientnet_blocks import SqueezeExcite -from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights,\ +from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, GroupNormAct +from ._builder import build_model_with_cfg, pretrained_cfg_for_features +from ._efficientnet_blocks import SqueezeExcite +from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \ round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT -from .features import FeatureInfo, FeatureHooks -from .helpers import build_model_with_cfg, pretrained_cfg_for_features, checkpoint_seq -from .layers import create_conv2d, create_classifier, get_norm_act_layer, EvoNorm2dS0, GroupNormAct -from .registry import register_model +from ._features import FeatureInfo, FeatureHooks +from ._manipulate import checkpoint_seq +from ._registry import register_model __all__ = ['EfficientNet', 'EfficientNetFeatures'] diff --git a/timm/models/gcvit.py b/timm/models/gcvit.py index fb375e2c..ec9b7e5e 100644 --- a/timm/models/gcvit.py +++ b/timm/models/gcvit.py @@ -28,12 +28,13 @@ import torch.nn as nn import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_function -from .helpers import build_model_with_cfg, named_apply -from .layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d,\ +from timm.layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d, \ get_attn, get_act_layer, get_norm_layer, _assert -from .registry import register_model -from .vision_transformer_relpos import RelPosMlp, RelPosBias # FIXME move to common location +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._manipulate import named_apply +from ._registry import register_model +from .vision_transformer_relpos import RelPosBias # FIXME move to common location __all__ = ['GlobalContextVit'] diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py index e19af88b..492049b9 100644 --- a/timm/models/ghostnet.py +++ b/timm/models/ghostnet.py @@ -11,13 +11,12 @@ import torch import torch.nn as nn import torch.nn.functional as F - from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .layers import SelectAdaptivePool2d, Linear, make_divisible -from .efficientnet_blocks import SqueezeExcite, ConvBnAct -from .helpers import build_model_with_cfg, checkpoint_seq -from .registry import register_model - +from timm.layers import SelectAdaptivePool2d, Linear, make_divisible +from ._builder import build_model_with_cfg +from ._efficientnet_blocks import SqueezeExcite, ConvBnAct +from ._manipulate import checkpoint_seq +from ._registry import register_model __all__ = ['GhostNet'] diff --git 
a/timm/models/gluon_resnet.py b/timm/models/gluon_resnet.py index a1e73554..2b4131fb 100644 --- a/timm/models/gluon_resnet.py +++ b/timm/models/gluon_resnet.py @@ -5,11 +5,13 @@ by Ross Wightman """ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import SEModule -from .registry import register_model +from timm.layers import SEModule +from ._builder import build_model_with_cfg +from ._registry import register_model from .resnet import ResNet, Bottleneck, BasicBlock +__all__ = [] + def _cfg(url='', **kwargs): return { diff --git a/timm/models/gluon_xception.py b/timm/models/gluon_xception.py index a9c946b2..b487d0fd 100644 --- a/timm/models/gluon_xception.py +++ b/timm/models/gluon_xception.py @@ -13,9 +13,9 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import create_classifier, get_padding -from .registry import register_model +from timm.layers import create_classifier, get_padding +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['Xception65'] diff --git a/timm/models/hardcorenas.py b/timm/models/hardcorenas.py index 132eeab4..d77e642a 100644 --- a/timm/models/hardcorenas.py +++ b/timm/models/hardcorenas.py @@ -3,12 +3,14 @@ from functools import partial import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .efficientnet_blocks import SqueezeExcite -from .efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args, round_channels -from .helpers import build_model_with_cfg, pretrained_cfg_for_features -from .layers import get_act_fn +from ._builder import build_model_with_cfg +from ._builder import pretrained_cfg_for_features +from ._efficientnet_blocks import SqueezeExcite +from ._efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args, round_channels +from ._registry import register_model from .mobilenetv3 import MobileNetV3, MobileNetV3Features -from .registry import register_model + +__all__ = [] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/helpers.py b/timm/models/helpers.py deleted file mode 100644 index 2a5551e0..00000000 --- a/timm/models/helpers.py +++ /dev/null @@ -1,855 +0,0 @@ -""" Model creation / weight loading / state_dict helpers - -Hacked together by / Copyright 2020 Ross Wightman -""" -import collections.abc -import dataclasses -import logging -import math -import os -import re -from collections import OrderedDict, defaultdict -from copy import deepcopy -from itertools import chain -from typing import Any, Callable, Optional, Tuple, Dict, Union - -import torch -import torch.nn as nn -from torch.hub import load_state_dict_from_url -from torch.utils.checkpoint import checkpoint - -from .pretrained import PretrainedCfg -from .features import FeatureListNet, FeatureDictNet, FeatureHookNet -from .fx_features import FeatureGraphNet -from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf -from .layers import Conv2dSame, Linear, BatchNormAct2d -from .registry import get_pretrained_cfg - - -_logger = logging.getLogger(__name__) - - -# Global variables for rarely used pretrained checkpoint download progress and hash check. -# Use set_pretrained_download_progress / set_pretrained_check_hash functions to toggle. 
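The download-progress and hash-check toggles referenced in the comment above survive the deletion of helpers.py; assuming they keep their names under the new _builder module (where the pretrained-loading code appears to land), usage stays:

from timm.models._builder import set_pretrained_download_progress, set_pretrained_check_hash

set_pretrained_download_progress(True)  # show a progress bar when pretrained weights are downloaded
set_pretrained_check_hash(True)         # verify the hash embedded in the checkpoint URL / filename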
-_DOWNLOAD_PROGRESS = False -_CHECK_HASH = False - - -def clean_state_dict(state_dict): - # 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training - cleaned_state_dict = OrderedDict() - for k, v in state_dict.items(): - name = k[7:] if k.startswith('module.') else k - cleaned_state_dict[name] = v - return cleaned_state_dict - - -def load_state_dict(checkpoint_path, use_ema=True): - if checkpoint_path and os.path.isfile(checkpoint_path): - checkpoint = torch.load(checkpoint_path, map_location='cpu') - state_dict_key = '' - if isinstance(checkpoint, dict): - if use_ema and checkpoint.get('state_dict_ema', None) is not None: - state_dict_key = 'state_dict_ema' - elif use_ema and checkpoint.get('model_ema', None) is not None: - state_dict_key = 'model_ema' - elif 'state_dict' in checkpoint: - state_dict_key = 'state_dict' - elif 'model' in checkpoint: - state_dict_key = 'model' - state_dict = clean_state_dict(checkpoint[state_dict_key] if state_dict_key else checkpoint) - _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path)) - return state_dict - else: - _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) - raise FileNotFoundError() - - -def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True, remap=False): - if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'): - # numpy checkpoint, try to load via model specific load_pretrained fn - if hasattr(model, 'load_pretrained'): - model.load_pretrained(checkpoint_path) - else: - raise NotImplementedError('Model cannot load numpy checkpoint') - return - state_dict = load_state_dict(checkpoint_path, use_ema) - if remap: - state_dict = remap_checkpoint(model, state_dict) - incompatible_keys = model.load_state_dict(state_dict, strict=strict) - return incompatible_keys - - -def remap_checkpoint(model, state_dict, allow_reshape=True): - """ remap checkpoint by iterating over state dicts in order (ignoring original keys). - This assumes models (and originating state dict) were created with params registered in same order. - """ - out_dict = {} - for (ka, va), (kb, vb) in zip(model.state_dict().items(), state_dict.items()): - assert va.numel == vb.numel, f'Tensor size mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.' - if va.shape != vb.shape: - if allow_reshape: - vb = vb.reshape(va.shape) - else: - assert False, f'Tensor shape mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.' 
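load_state_dict / load_checkpoint deleted in this hunk are relocated rather than removed (the new timm/models/_helpers.py in the file list is their destination). Typical use, with a hypothetical checkpoint path and assuming the package __init__ keeps re-exporting them:

import timm
from timm.models import load_checkpoint

model = timm.create_model('resnet18', pretrained=False)
# prefers EMA weights ('state_dict_ema' / 'model_ema') when present and strips any 'module.' prefix
load_checkpoint(model, 'output/train/last.pth.tar', use_ema=True, strict=True)  # path is illustrative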
- out_dict[ka] = vb - return out_dict - - -def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True): - resume_epoch = None - if os.path.isfile(checkpoint_path): - checkpoint = torch.load(checkpoint_path, map_location='cpu') - if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: - if log_info: - _logger.info('Restoring model state from checkpoint...') - state_dict = clean_state_dict(checkpoint['state_dict']) - model.load_state_dict(state_dict) - - if optimizer is not None and 'optimizer' in checkpoint: - if log_info: - _logger.info('Restoring optimizer state from checkpoint...') - optimizer.load_state_dict(checkpoint['optimizer']) - - if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint: - if log_info: - _logger.info('Restoring AMP loss scaler state from checkpoint...') - loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key]) - - if 'epoch' in checkpoint: - resume_epoch = checkpoint['epoch'] - if 'version' in checkpoint and checkpoint['version'] > 1: - resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save - - if log_info: - _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) - else: - model.load_state_dict(checkpoint) - if log_info: - _logger.info("Loaded checkpoint '{}'".format(checkpoint_path)) - return resume_epoch - else: - _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) - raise FileNotFoundError() - - -def _resolve_pretrained_source(pretrained_cfg): - cfg_source = pretrained_cfg.get('source', '') - pretrained_url = pretrained_cfg.get('url', None) - pretrained_file = pretrained_cfg.get('file', None) - hf_hub_id = pretrained_cfg.get('hf_hub_id', None) - # resolve where to load pretrained weights from - load_from = '' - pretrained_loc = '' - if cfg_source == 'hf-hub' and has_hf_hub(necessary=True): - # hf-hub specified as source via model identifier - load_from = 'hf-hub' - assert hf_hub_id - pretrained_loc = hf_hub_id - else: - # default source == timm or unspecified - if pretrained_file: - load_from = 'file' - pretrained_loc = pretrained_file - elif pretrained_url: - load_from = 'url' - pretrained_loc = pretrained_url - elif hf_hub_id and has_hf_hub(necessary=True): - # hf-hub available as alternate weight source in default_cfg - load_from = 'hf-hub' - pretrained_loc = hf_hub_id - if load_from == 'hf-hub' and pretrained_cfg.get('hf_hub_filename', None): - # if a filename override is set, return tuple for location w/ (hub_id, filename) - pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename'] - return load_from, pretrained_loc - - -def set_pretrained_download_progress(enable=True): - """ Set download progress for pretrained weights on/off (globally). """ - global _DOWNLOAD_PROGRESS - _DOWNLOAD_PROGRESS = enable - - -def set_pretrained_check_hash(enable=True): - """ Set hash checking for pretrained weights on/off (globally). """ - global _CHECK_HASH - _CHECK_HASH = enable - - -def load_custom_pretrained( - model: nn.Module, - pretrained_cfg: Optional[Dict] = None, - load_fn: Optional[Callable] = None, -): - r"""Loads a custom (read non .pth) weight file - - Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls - a passed in custom load fun, or the `load_pretrained` model member fn. - - If the object is already present in `model_dir`, it's deserialized and returned. 
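resume_checkpoint, deleted just above, follows the same move; it restores model, optimizer, AMP loss-scaler and epoch state in one call. A sketch of the train-loop usage under the same re-export assumption, again with an illustrative path:

import timm
import torch
from timm.models import resume_checkpoint

model = timm.create_model('resnet18', pretrained=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
# returns the epoch to resume from when the checkpoint carries epoch info, otherwise None
resume_epoch = resume_checkpoint(model, 'output/train/last.pth.tar', optimizer=optimizer, log_info=True)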
- The default value of `model_dir` is ``/checkpoints`` where - `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`. - - Args: - model: The instantiated model to load weights into - pretrained_cfg (dict): Default pretrained model cfg - load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named - 'laod_pretrained' on the model will be called if it exists - """ - pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None) - if not pretrained_cfg: - _logger.warning("Invalid pretrained config, cannot load weights.") - return - - load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg) - if not load_from: - _logger.warning("No pretrained weights exist for this model. Using random initialization.") - return - if load_from == 'hf-hub': # FIXME - _logger.warning("Hugging Face hub not currently supported for custom load pretrained models.") - elif load_from == 'url': - pretrained_loc = download_cached_file( - pretrained_loc, - check_hash=_CHECK_HASH, - progress=_DOWNLOAD_PROGRESS - ) - - if load_fn is not None: - load_fn(model, pretrained_loc) - elif hasattr(model, 'load_pretrained'): - model.load_pretrained(pretrained_loc) - else: - _logger.warning("Valid function to load pretrained weights is not available, using random initialization.") - - -def adapt_input_conv(in_chans, conv_weight): - conv_type = conv_weight.dtype - conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU - O, I, J, K = conv_weight.shape - if in_chans == 1: - if I > 3: - assert conv_weight.shape[1] % 3 == 0 - # For models with space2depth stems - conv_weight = conv_weight.reshape(O, I // 3, 3, J, K) - conv_weight = conv_weight.sum(dim=2, keepdim=False) - else: - conv_weight = conv_weight.sum(dim=1, keepdim=True) - elif in_chans != 3: - if I != 3: - raise NotImplementedError('Weight format not supported by conversion.') - else: - # NOTE this strategy should be better than random init, but there could be other combinations of - # the original RGB input layer weights that'd work better for specific cases. 
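adapt_input_conv, whose deleted copy appears in this hunk, is re-added to _manipulate.py earlier in this patch; it converts a pretrained RGB first-conv weight to an arbitrary number of input channels (channel sum for 1-channel input, tile-and-rescale otherwise). A quick standalone check with a random stand-in tensor:

import torch
from timm.models._manipulate import adapt_input_conv

rgb_w = torch.randn(64, 3, 7, 7)      # stand-in for a pretrained stem conv weight (O, I, H, W)
gray_w = adapt_input_conv(1, rgb_w)   # (64, 1, 7, 7): the three RGB kernels are summed
six_w = adapt_input_conv(6, rgb_w)    # (64, 6, 7, 7): kernels tiled, then scaled by 3/6
print(gray_w.shape, six_w.shape)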
- repeat = int(math.ceil(in_chans / 3)) - conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] - conv_weight *= (3 / float(in_chans)) - conv_weight = conv_weight.to(conv_type) - return conv_weight - - -def load_pretrained( - model: nn.Module, - pretrained_cfg: Optional[Dict] = None, - num_classes: int = 1000, - in_chans: int = 3, - filter_fn: Optional[Callable] = None, - strict: bool = True, -): - """ Load pretrained checkpoint - - Args: - model (nn.Module) : PyTorch model module - pretrained_cfg (Optional[Dict]): configuration for pretrained weights / target dataset - num_classes (int): num_classes for target model - in_chans (int): in_chans for target model - filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args) - strict (bool): strict load of checkpoint - - """ - pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None) - if not pretrained_cfg: - _logger.warning("Invalid pretrained config, cannot load weights.") - return - - load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg) - if load_from == 'file': - _logger.info(f'Loading pretrained weights from file ({pretrained_loc})') - state_dict = load_state_dict(pretrained_loc) - elif load_from == 'url': - _logger.info(f'Loading pretrained weights from url ({pretrained_loc})') - state_dict = load_state_dict_from_url( - pretrained_loc, - map_location='cpu', - progress=_DOWNLOAD_PROGRESS, - check_hash=_CHECK_HASH, - ) - elif load_from == 'hf-hub': - _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})') - if isinstance(pretrained_loc, (list, tuple)): - state_dict = load_state_dict_from_hf(*pretrained_loc) - else: - state_dict = load_state_dict_from_hf(pretrained_loc) - else: - _logger.warning("No pretrained weights exist or were found for this model. 
Using random initialization.") - return - - if filter_fn is not None: - # for backwards compat with filter fn that take one arg, try one first, the two - try: - state_dict = filter_fn(state_dict) - except TypeError: - state_dict = filter_fn(state_dict, model) - - input_convs = pretrained_cfg.get('first_conv', None) - if input_convs is not None and in_chans != 3: - if isinstance(input_convs, str): - input_convs = (input_convs,) - for input_conv_name in input_convs: - weight_name = input_conv_name + '.weight' - try: - state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name]) - _logger.info( - f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)') - except NotImplementedError as e: - del state_dict[weight_name] - strict = False - _logger.warning( - f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.') - - classifiers = pretrained_cfg.get('classifier', None) - label_offset = pretrained_cfg.get('label_offset', 0) - if classifiers is not None: - if isinstance(classifiers, str): - classifiers = (classifiers,) - if num_classes != pretrained_cfg['num_classes']: - for classifier_name in classifiers: - # completely discard fully connected if model num_classes doesn't match pretrained weights - state_dict.pop(classifier_name + '.weight', None) - state_dict.pop(classifier_name + '.bias', None) - strict = False - elif label_offset > 0: - for classifier_name in classifiers: - # special case for pretrained weights with an extra background class in pretrained weights - classifier_weight = state_dict[classifier_name + '.weight'] - state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:] - classifier_bias = state_dict[classifier_name + '.bias'] - state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:] - - model.load_state_dict(state_dict, strict=strict) - - -def extract_layer(model, layer): - layer = layer.split('.') - module = model - if hasattr(model, 'module') and layer[0] != 'module': - module = model.module - if not hasattr(model, 'module') and layer[0] == 'module': - layer = layer[1:] - for l in layer: - if hasattr(module, l): - if not l.isdigit(): - module = getattr(module, l) - else: - module = module[int(l)] - else: - return module - return module - - -def set_layer(model, layer, val): - layer = layer.split('.') - module = model - if hasattr(model, 'module') and layer[0] != 'module': - module = model.module - lst_index = 0 - module2 = module - for l in layer: - if hasattr(module2, l): - if not l.isdigit(): - module2 = getattr(module2, l) - else: - module2 = module2[int(l)] - lst_index += 1 - lst_index -= 1 - for l in layer[:lst_index]: - if not l.isdigit(): - module = getattr(module, l) - else: - module = module[int(l)] - l = layer[lst_index] - setattr(module, l, val) - - -def adapt_model_from_string(parent_module, model_string): - separator = '***' - state_dict = {} - lst_shape = model_string.split(separator) - for k in lst_shape: - k = k.split(':') - key = k[0] - shape = k[1][1:-1].split(',') - if shape[0] != '': - state_dict[key] = [int(i) for i in shape] - - new_module = deepcopy(parent_module) - for n, m in parent_module.named_modules(): - old_module = extract_layer(parent_module, n) - if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame): - if isinstance(old_module, Conv2dSame): - conv = Conv2dSame - else: - conv = nn.Conv2d - s = state_dict[n + '.weight'] - in_channels = s[1] - out_channels = s[0] - g = 1 - if old_module.groups > 1: - 
in_channels = out_channels - g = in_channels - new_conv = conv( - in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size, - bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation, - groups=g, stride=old_module.stride) - set_layer(new_module, n, new_conv) - elif isinstance(old_module, BatchNormAct2d): - new_bn = BatchNormAct2d( - state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum, - affine=old_module.affine, track_running_stats=True) - new_bn.drop = old_module.drop - new_bn.act = old_module.act - set_layer(new_module, n, new_bn) - elif isinstance(old_module, nn.BatchNorm2d): - new_bn = nn.BatchNorm2d( - num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum, - affine=old_module.affine, track_running_stats=True) - set_layer(new_module, n, new_bn) - elif isinstance(old_module, nn.Linear): - # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer? - num_features = state_dict[n + '.weight'][1] - new_fc = Linear( - in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None) - set_layer(new_module, n, new_fc) - if hasattr(new_module, 'num_features'): - new_module.num_features = num_features - new_module.eval() - parent_module.eval() - - return new_module - - -def adapt_model_from_file(parent_module, model_variant): - adapt_file = os.path.join(os.path.dirname(__file__), 'pruned', model_variant + '.txt') - with open(adapt_file, 'r') as f: - return adapt_model_from_string(parent_module, f.read().strip()) - - -def pretrained_cfg_for_features(pretrained_cfg): - pretrained_cfg = deepcopy(pretrained_cfg) - # remove default pretrained cfg fields that don't have much relevance for feature backbone - to_remove = ('num_classes', 'classifier', 'global_pool') # add default final pool size? 
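pretrained_cfg_for_features, whose deletion continues just below and which re-appears under _builder, strips classifier-specific fields from the pretrained config when a model is wrapped as a feature backbone; from the user side that path is reached through features_only:

import timm

# build_model_with_cfg wraps the model in a FeatureListNet (or a hook / FX variant) and attaches
# the classifier-free pretrained_cfg produced by pretrained_cfg_for_features
model = timm.create_model('resnet18', pretrained=False, features_only=True, out_indices=(1, 2, 3))
print(model.feature_info.channels())   # [64, 128, 256]
print(model.feature_info.reduction())  # [4, 8, 16]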
- for tr in to_remove: - pretrained_cfg.pop(tr, None) - return pretrained_cfg - - -def _filter_kwargs(kwargs, names): - if not kwargs or not names: - return - for n in names: - kwargs.pop(n, None) - - -def _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter): - """ Update the default_cfg and kwargs before passing to model - - Args: - pretrained_cfg: input pretrained cfg (updated in-place) - kwargs: keyword args passed to model build fn (updated in-place) - kwargs_filter: keyword arg keys that must be removed before model __init__ - """ - # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs) - default_kwarg_names = ('num_classes', 'global_pool', 'in_chans') - if pretrained_cfg.get('fixed_input_size', False): - # if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size - default_kwarg_names += ('img_size',) - - for n in default_kwarg_names: - # for legacy reasons, model __init__args uses img_size + in_chans as separate args while - # pretrained_cfg has one input_size=(C, H ,W) entry - if n == 'img_size': - input_size = pretrained_cfg.get('input_size', None) - if input_size is not None: - assert len(input_size) == 3 - kwargs.setdefault(n, input_size[-2:]) - elif n == 'in_chans': - input_size = pretrained_cfg.get('input_size', None) - if input_size is not None: - assert len(input_size) == 3 - kwargs.setdefault(n, input_size[0]) - else: - default_val = pretrained_cfg.get(n, None) - if default_val is not None: - kwargs.setdefault(n, pretrained_cfg[n]) - - # Filter keyword args for task specific model variants (some 'features only' models, etc.) - _filter_kwargs(kwargs, names=kwargs_filter) - - -def resolve_pretrained_cfg( - variant: str, - pretrained_cfg=None, - pretrained_cfg_overlay=None, -) -> PretrainedCfg: - model_with_tag = variant - pretrained_tag = None - if pretrained_cfg: - if isinstance(pretrained_cfg, dict): - # pretrained_cfg dict passed as arg, validate by converting to PretrainedCfg - pretrained_cfg = PretrainedCfg(**pretrained_cfg) - elif isinstance(pretrained_cfg, str): - pretrained_tag = pretrained_cfg - pretrained_cfg = None - - # fallback to looking up pretrained cfg in model registry by variant identifier - if not pretrained_cfg: - if pretrained_tag: - model_with_tag = '.'.join([variant, pretrained_tag]) - pretrained_cfg = get_pretrained_cfg(model_with_tag) - - if not pretrained_cfg: - _logger.warning( - f"No pretrained configuration specified for {model_with_tag} model. Using a default." 
- f" Please add a config to the model pretrained_cfg registry or pass explicitly.") - pretrained_cfg = PretrainedCfg() # instance with defaults - - pretrained_cfg_overlay = pretrained_cfg_overlay or {} - if not pretrained_cfg.architecture: - pretrained_cfg_overlay.setdefault('architecture', variant) - pretrained_cfg = dataclasses.replace(pretrained_cfg, **pretrained_cfg_overlay) - - return pretrained_cfg - - -def build_model_with_cfg( - model_cls: Callable, - variant: str, - pretrained: bool, - pretrained_cfg: Optional[Dict] = None, - pretrained_cfg_overlay: Optional[Dict] = None, - model_cfg: Optional[Any] = None, - feature_cfg: Optional[Dict] = None, - pretrained_strict: bool = True, - pretrained_filter_fn: Optional[Callable] = None, - kwargs_filter: Optional[Tuple[str]] = None, - **kwargs, -): - """ Build model with specified default_cfg and optional model_cfg - - This helper fn aids in the construction of a model including: - * handling default_cfg and associated pretrained weight loading - * passing through optional model_cfg for models with config based arch spec - * features_only model adaptation - * pruning config / model adaptation - - Args: - model_cls (nn.Module): model class - variant (str): model variant name - pretrained (bool): load pretrained weights - pretrained_cfg (dict): model's pretrained weight/task config - model_cfg (Optional[Dict]): model's architecture config - feature_cfg (Optional[Dict]: feature extraction adapter config - pretrained_strict (bool): load pretrained weights strictly - pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights - kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model - **kwargs: model args passed through to model __init__ - """ - pruned = kwargs.pop('pruned', False) - features = False - feature_cfg = feature_cfg or {} - - # resolve and update model pretrained config and model kwargs - pretrained_cfg = resolve_pretrained_cfg( - variant, - pretrained_cfg=pretrained_cfg, - pretrained_cfg_overlay=pretrained_cfg_overlay - ) - - # FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model - pretrained_cfg = pretrained_cfg.to_dict() - - _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter) - - # Setup for feature extraction wrapper done at end of this fn - if kwargs.pop('features_only', False): - features = True - feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4)) - if 'out_indices' in kwargs: - feature_cfg['out_indices'] = kwargs.pop('out_indices') - - # Instantiate the model - if model_cfg is None: - model = model_cls(**kwargs) - else: - model = model_cls(cfg=model_cfg, **kwargs) - model.pretrained_cfg = pretrained_cfg - model.default_cfg = model.pretrained_cfg # alias for backwards compat - - if pruned: - model = adapt_model_from_file(model, variant) - - # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats - num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000)) - if pretrained: - if pretrained_cfg.get('custom_load', False): - load_custom_pretrained( - model, - pretrained_cfg=pretrained_cfg, - ) - else: - load_pretrained( - model, - pretrained_cfg=pretrained_cfg, - num_classes=num_classes_pretrained, - in_chans=kwargs.get('in_chans', 3), - filter_fn=pretrained_filter_fn, - strict=pretrained_strict, - ) - - # Wrap the model in a feature extraction module if enabled - if features: - feature_cls = FeatureListNet - if 'feature_cls' in feature_cfg: - 
feature_cls = feature_cfg.pop('feature_cls') - if isinstance(feature_cls, str): - feature_cls = feature_cls.lower() - if 'hook' in feature_cls: - feature_cls = FeatureHookNet - elif feature_cls == 'fx': - feature_cls = FeatureGraphNet - else: - assert False, f'Unknown feature class {feature_cls}' - model = feature_cls(model, **feature_cfg) - model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg) # add back default_cfg - model.default_cfg = model.pretrained_cfg # alias for backwards compat - - return model - - -def model_parameters(model, exclude_head=False): - if exclude_head: - # FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering - return [p for p in model.parameters()][:-2] - else: - return model.parameters() - - -def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module: - if not depth_first and include_root: - fn(module=module, name=name) - for child_name, child_module in module.named_children(): - child_name = '.'.join((name, child_name)) if name else child_name - named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) - if depth_first and include_root: - fn(module=module, name=name) - return module - - -def named_modules(module: nn.Module, name='', depth_first=True, include_root=False): - if not depth_first and include_root: - yield name, module - for child_name, child_module in module.named_children(): - child_name = '.'.join((name, child_name)) if name else child_name - yield from named_modules( - module=child_module, name=child_name, depth_first=depth_first, include_root=True) - if depth_first and include_root: - yield name, module - - -def named_modules_with_params(module: nn.Module, name='', depth_first=True, include_root=False): - if module._parameters and not depth_first and include_root: - yield name, module - for child_name, child_module in module.named_children(): - child_name = '.'.join((name, child_name)) if name else child_name - yield from named_modules_with_params( - module=child_module, name=child_name, depth_first=depth_first, include_root=True) - if module._parameters and depth_first and include_root: - yield name, module - - -MATCH_PREV_GROUP = (99999,) - - -def group_with_matcher( - named_objects, - group_matcher: Union[Dict, Callable], - output_values: bool = False, - reverse: bool = False -): - if isinstance(group_matcher, dict): - # dictionary matcher contains a dict of raw-string regex expr that must be compiled - compiled = [] - for group_ordinal, (group_name, mspec) in enumerate(group_matcher.items()): - if mspec is None: - continue - # map all matching specifications into 3-tuple (compiled re, prefix, suffix) - if isinstance(mspec, (tuple, list)): - # multi-entry match specifications require each sub-spec to be a 2-tuple (re, suffix) - for sspec in mspec: - compiled += [(re.compile(sspec[0]), (group_ordinal,), sspec[1])] - else: - compiled += [(re.compile(mspec), (group_ordinal,), None)] - group_matcher = compiled - - def _get_grouping(name): - if isinstance(group_matcher, (list, tuple)): - for match_fn, prefix, suffix in group_matcher: - r = match_fn.match(name) - if r: - parts = (prefix, r.groups(), suffix) - # map all tuple elem to int for numeric sort, filter out None entries - return tuple(map(float, chain.from_iterable(filter(None, parts)))) - return float('inf'), # un-matched layers (neck, head) mapped to largest ordinal - else: - ord = group_matcher(name) - if not isinstance(ord, 
collections.abc.Iterable): - return ord, - return tuple(ord) - - # map layers into groups via ordinals (ints or tuples of ints) from matcher - grouping = defaultdict(list) - for k, v in named_objects: - grouping[_get_grouping(k)].append(v if output_values else k) - - # remap to integers - layer_id_to_param = defaultdict(list) - lid = -1 - for k in sorted(filter(lambda x: x is not None, grouping.keys())): - if lid < 0 or k[-1] != MATCH_PREV_GROUP[0]: - lid += 1 - layer_id_to_param[lid].extend(grouping[k]) - - if reverse: - assert not output_values, "reverse mapping only sensible for name output" - # output reverse mapping - param_to_layer_id = {} - for lid, lm in layer_id_to_param.items(): - for n in lm: - param_to_layer_id[n] = lid - return param_to_layer_id - - return layer_id_to_param - - -def group_parameters( - module: nn.Module, - group_matcher, - output_values=False, - reverse=False, -): - return group_with_matcher( - module.named_parameters(), group_matcher, output_values=output_values, reverse=reverse) - - -def group_modules( - module: nn.Module, - group_matcher, - output_values=False, - reverse=False, -): - return group_with_matcher( - named_modules_with_params(module), group_matcher, output_values=output_values, reverse=reverse) - - -def checkpoint_seq( - functions, - x, - every=1, - flatten=False, - skip_last=False, - preserve_rng_state=True -): - r"""A helper function for checkpointing sequential models. - - Sequential models execute a list of modules/functions in order - (sequentially). Therefore, we can divide such a sequence into segments - and checkpoint each segment. All segments except run in :func:`torch.no_grad` - manner, i.e., not storing the intermediate activations. The inputs of each - checkpointed segment will be saved for re-running the segment in the backward pass. - - See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works. - - .. warning:: - Checkpointing currently only supports :func:`torch.autograd.backward` - and only if its `inputs` argument is not passed. :func:`torch.autograd.grad` - is not supported. - - .. warning: - At least one of the inputs needs to have :code:`requires_grad=True` if - grads are needed for model inputs, otherwise the checkpointed part of the - model won't have gradients. - - Args: - functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially. - x: A Tensor that is input to :attr:`functions` - every: checkpoint every-n functions (default: 1) - flatten (bool): flatten nn.Sequential of nn.Sequentials - skip_last (bool): skip checkpointing the last function in the sequence if True - preserve_rng_state (bool, optional, default=True): Omit stashing and restoring - the RNG state during each checkpoint. - - Returns: - Output of running :attr:`functions` sequentially on :attr:`*inputs` - - Example: - >>> model = nn.Sequential(...) 
- >>> input_var = checkpoint_seq(model, input_var, every=2) - """ - def run_function(start, end, functions): - def forward(_x): - for j in range(start, end + 1): - _x = functions[j](_x) - return _x - return forward - - if isinstance(functions, torch.nn.Sequential): - functions = functions.children() - if flatten: - functions = chain.from_iterable(functions) - if not isinstance(functions, (tuple, list)): - functions = tuple(functions) - - num_checkpointed = len(functions) - if skip_last: - num_checkpointed -= 1 - end = -1 - for start in range(0, num_checkpointed, every): - end = min(start + every - 1, num_checkpointed - 1) - x = checkpoint(run_function(start, end, functions), x, preserve_rng_state=preserve_rng_state) - if skip_last: - return run_function(end + 1, len(functions) - 1, functions)(x) - return x - - -def flatten_modules(named_modules, depth=1, prefix='', module_types='sequential'): - prefix_is_tuple = isinstance(prefix, tuple) - if isinstance(module_types, str): - if module_types == 'container': - module_types = (nn.Sequential, nn.ModuleList, nn.ModuleDict) - else: - module_types = (nn.Sequential,) - for name, module in named_modules: - if depth and isinstance(module, module_types): - yield from flatten_modules( - module.named_children(), - depth - 1, - prefix=(name,) if prefix_is_tuple else name, - module_types=module_types, - ) - else: - if prefix_is_tuple: - name = prefix + (name,) - yield name, module - else: - if prefix: - name = '.'.join([prefix, name]) - yield name, module diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py index 30860120..338d409e 100644 --- a/timm/models/hrnet.py +++ b/timm/models/hrnet.py @@ -16,12 +16,14 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .features import FeatureInfo -from .helpers import build_model_with_cfg, pretrained_cfg_for_features -from .layers import create_classifier -from .registry import register_model +from timm.layers import create_classifier +from ._builder import build_model_with_cfg, pretrained_cfg_for_features +from ._features import FeatureInfo +from ._registry import register_model from .resnet import BasicBlock, Bottleneck # leveraging ResNet blocks w/ additional features like SE +__all__ = ['HighResolutionNet', 'HighResolutionNetFeatures'] # model_registry will add each entrypoint fn to this + _BN_MOMENTUM = 0.1 _logger = logging.getLogger(__name__) diff --git a/timm/models/inception_resnet_v2.py b/timm/models/inception_resnet_v2.py index fa7b8ec8..3006f3d2 100644 --- a/timm/models/inception_resnet_v2.py +++ b/timm/models/inception_resnet_v2.py @@ -7,9 +7,10 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg, flatten_modules -from .layers import create_classifier -from .registry import register_model +from timm.layers import create_classifier +from ._builder import build_model_with_cfg +from ._manipulate import flatten_modules +from ._registry import register_model __all__ = ['InceptionResnetV2'] diff --git a/timm/models/inception_v3.py b/timm/models/inception_v3.py index c70bd608..28794ce6 100644 --- a/timm/models/inception_v3.py +++ b/timm/models/inception_v3.py @@ -8,9 +8,13 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg, resolve_pretrained_cfg, 
flatten_modules -from .registry import register_model -from .layers import trunc_normal_, create_classifier, Linear +from timm.layers import trunc_normal_, create_classifier, Linear +from ._builder import build_model_with_cfg +from ._builder import resolve_pretrained_cfg +from ._manipulate import flatten_modules +from ._registry import register_model + +__all__ = ['InceptionV3', 'InceptionV3Aux'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/inception_v4.py b/timm/models/inception_v4.py index 5f4e208f..c1559829 100644 --- a/timm/models/inception_v4.py +++ b/timm/models/inception_v4.py @@ -7,9 +7,9 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg -from .layers import create_classifier -from .registry import register_model +from timm.layers import create_classifier +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['InceptionV4'] diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 21c641b6..1bfb95d5 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -1,44 +1,45 @@ -from .activations import * -from .adaptive_avgmax_pool import \ +# NOTE timm.models.layers is DEPRECATED, please use timm.layers, this is here to reduce breakages in transition +from timm.layers.activations import * +from timm.layers.adaptive_avgmax_pool import \ adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d -from .blur_pool import BlurPool2d -from .classifier import ClassifierHead, create_classifier -from .cond_conv2d import CondConv2d, get_condconv_initializer -from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ +from timm.layers.blur_pool import BlurPool2d +from timm.layers.classifier import ClassifierHead, create_classifier +from timm.layers.cond_conv2d import CondConv2d, get_condconv_initializer +from timm.layers.config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ set_layer_config -from .conv2d_same import Conv2dSame, conv2d_same -from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct -from .create_act import create_act_layer, get_act_layer, get_act_fn -from .create_attn import get_attn, create_attn -from .create_conv2d import create_conv2d -from .create_norm import get_norm_layer, create_norm_layer -from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer -from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path -from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn -from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\ +from timm.layers.conv2d_same import Conv2dSame, conv2d_same +from timm.layers.conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct +from timm.layers.create_act import create_act_layer, get_act_layer, get_act_fn +from timm.layers.create_attn import get_attn, create_attn +from timm.layers.create_conv2d import create_conv2d +from timm.layers.create_norm import get_norm_layer, create_norm_layer +from timm.layers.create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer +from timm.layers.drop import DropBlock2d, DropPath, drop_block_2d, drop_path +from timm.layers.eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn +from 
timm.layers.evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\ EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a -from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm -from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d -from .gather_excite import GatherExcite -from .global_context import GlobalContext -from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple -from .inplace_abn import InplaceAbn -from .linear import Linear -from .mixed_conv2d import MixedConv2d -from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp -from .non_local_attn import NonLocalAttn, BatNonLocalAttn -from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d -from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm -from .padding import get_padding, get_same_padding, pad_same -from .patch_embed import PatchEmbed -from .pool2d_same import AvgPool2dSame, create_pool2d -from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite -from .selective_kernel import SelectiveKernel -from .separable_conv import SeparableConv2d, SeparableConvNormAct -from .space_to_depth import SpaceToDepthModule -from .split_attn import SplitAttn -from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model -from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame -from .test_time_pool import TestTimePoolHead, apply_test_time_pool -from .trace_utils import _assert, _float_to_int -from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_ +from timm.layers.fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm +from timm.layers.filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d +from timm.layers.gather_excite import GatherExcite +from timm.layers.global_context import GlobalContext +from timm.layers.helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple +from timm.layers.inplace_abn import InplaceAbn +from timm.layers.linear import Linear +from timm.layers.mixed_conv2d import MixedConv2d +from timm.layers.mlp import Mlp, GluMlp, GatedMlp, ConvMlp +from timm.layers.non_local_attn import NonLocalAttn, BatNonLocalAttn +from timm.layers.norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d +from timm.layers.norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm +from timm.layers.padding import get_padding, get_same_padding, pad_same +from timm.layers.patch_embed import PatchEmbed +from timm.layers.pool2d_same import AvgPool2dSame, create_pool2d +from timm.layers.squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite +from timm.layers.selective_kernel import SelectiveKernel +from timm.layers.separable_conv import SeparableConv2d, SeparableConvNormAct +from timm.layers.space_to_depth import SpaceToDepthModule +from timm.layers.split_attn import SplitAttn +from timm.layers.split_batchnorm import SplitBatchNorm2d, convert_splitbn_model +from timm.layers.std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame +from timm.layers.test_time_pool import TestTimePoolHead, apply_test_time_pool +from timm.layers.trace_utils import _assert, _float_to_int +from timm.layers.weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_ diff --git a/timm/models/levit.py b/timm/models/levit.py index cea9f0fc..8dc11309 100644 --- 
a/timm/models/levit.py +++ b/timm/models/levit.py @@ -23,8 +23,6 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W # Modified from # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # Copyright 2020 Ross Wightman, Apache-2.0 License -import itertools -from copy import deepcopy from functools import partial from typing import Dict @@ -32,10 +30,12 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import to_ntuple, get_act_layer -from .vision_transformer import trunc_normal_ -from .registry import register_model +from timm.layers import to_ntuple, get_act_layer, trunc_normal_ +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model + +__all__ = ['LevitDistilled'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index 3f315093..1e2666e5 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -45,17 +45,17 @@ from typing import Callable, Optional, Union, Tuple, List import torch from torch import nn -from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, checkpoint_seq, named_apply -from .fx_features import register_notrace_function -from .layers import Mlp, ConvMlp, DropPath, ClassifierHead, trunc_normal_tf_, LayerNorm2d, LayerNorm -from .layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d -from .layers import SelectAdaptivePool2d, create_pool2d -from .layers import to_2tuple, extend_tuple, make_divisible, _assert -from .pretrained import generate_default_cfgs -from .registry import register_model +from timm.layers import Mlp, ConvMlp, DropPath, ClassifierHead, trunc_normal_tf_, LayerNorm +from timm.layers import SelectAdaptivePool2d, create_pool2d +from timm.layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d +from timm.layers import to_2tuple, extend_tuple, make_divisible, _assert +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._manipulate import named_apply, checkpoint_seq +from ._pretrained import generate_default_cfgs +from ._registry import register_model from .vision_transformer_relpos import RelPosMlp, RelPosBias # FIXME move these to common location __all__ = ['MaxxVitCfg', 'MaxxVitConvCfg', 'MaxxVitTransformerCfg', 'MaxxVit'] diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py index a77e2eb7..a7825899 100644 --- a/timm/models/mlp_mixer.py +++ b/timm/models/mlp_mixer.py @@ -39,16 +39,18 @@ A thank you to paper authors for releasing code and weights. 
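As context for the feature wrapping shown at the start of this section (FeatureHookNet for hook-based extraction, FeatureGraphNet for the FX path, with pretrained_cfg_for_features re-attaching the config), here is a minimal sketch of the features_only flow that logic backs. The model name is only an example; any architecture exposing feature_info behaves the same way.

import torch
import timm

m = timm.create_model('resnet50', pretrained=False, features_only=True)
out = m(torch.randn(1, 3, 224, 224))
print([o.shape for o in out])        # one tensor per feature stage
print(m.feature_info.channels())     # channel counts reported via FeatureInfo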
Hacked together by / Copyright 2021 Ross Wightman """ import math -from copy import deepcopy from functools import partial import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, named_apply, checkpoint_seq -from .layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple -from .registry import register_model +from timm.layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple +from ._builder import build_model_with_cfg +from ._manipulate import named_apply, checkpoint_seq +from ._registry import register_model + +__all__ = ['MixerBlock'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index bb72ccb8..cf4f268d 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -14,13 +14,14 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .efficientnet_blocks import SqueezeExcite -from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights,\ +from timm.layers import SelectAdaptivePool2d, Linear, create_conv2d, get_norm_act_layer +from ._builder import build_model_with_cfg, pretrained_cfg_for_features +from ._efficientnet_blocks import SqueezeExcite +from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \ round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT -from .features import FeatureInfo, FeatureHooks -from .helpers import build_model_with_cfg, pretrained_cfg_for_features, checkpoint_seq -from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, get_norm_act_layer -from .registry import register_model +from ._features import FeatureInfo, FeatureHooks +from ._manipulate import checkpoint_seq +from ._registry import register_model __all__ = ['MobileNetV3', 'MobileNetV3Features'] diff --git a/timm/models/mobilevit.py b/timm/models/mobilevit.py index bd5479a7..3d2ae84a 100644 --- a/timm/models/mobilevit.py +++ b/timm/models/mobilevit.py @@ -14,18 +14,18 @@ Rest of code, ByobNet, and Transformer block hacked together by / Copyright 2022 # Copyright (C) 2020 Apple Inc. All Rights Reserved. 
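The rewritten timm/models/layers/__init__.py above is now a thin re-export of timm.layers, so imports against the old path keep resolving during the transition while new code targets the new package. A short sketch of the migration for downstream code; both symbols shown are among those re-exported by the shim.

# Deprecated location, still resolves through the shim above:
#   from timm.models.layers import DropPath, trunc_normal_
# Preferred location after this restructure:
from timm.layers import DropPath, trunc_normal_

import torch.nn as nn

class Block(nn.Module):
    def __init__(self, dim, drop_prob=0.1):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.drop_path = DropPath(drop_prob)
        trunc_normal_(self.proj.weight, std=0.02)

    def forward(self, x):
        return x + self.drop_path(self.proj(x))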
# import math -from typing import Union, Callable, Dict, Tuple, Optional, Sequence +from typing import Callable, Tuple, Optional import torch -from torch import nn import torch.nn.functional as F +from torch import nn +from timm.layers import to_2tuple, make_divisible, GroupNorm1, ConvMlp, DropPath +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_module +from ._registry import register_model from .byobnet import register_block, ByoBlockCfg, ByoModelCfg, ByobNet, LayerFn, num_groups -from .fx_features import register_notrace_module -from .layers import to_2tuple, make_divisible, LayerNorm2d, GroupNorm1, ConvMlp, DropPath from .vision_transformer import Block as TransformerBlock -from .helpers import build_model_with_cfg -from .registry import register_model __all__ = [] diff --git a/timm/models/mvitv2.py b/timm/models/mvitv2.py index c5aaa09e..5c0a6650 100644 --- a/timm/models/mvitv2.py +++ b/timm/models/mvitv2.py @@ -24,10 +24,12 @@ import torch.utils.checkpoint as checkpoint from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_function -from .helpers import build_model_with_cfg -from .layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple -from .registry import register_model +from timm.layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._registry import register_model + +__all__ = ['MultiScaleVit', 'MultiScaleVitCfg'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py index 50db1a3d..0b2178d6 100644 --- a/timm/models/nasnet.py +++ b/timm/models/nasnet.py @@ -8,9 +8,9 @@ import torch import torch.nn as nn import torch.nn.functional as F -from .helpers import build_model_with_cfg -from .layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier -from .registry import register_model +from timm.layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['NASNetALarge'] diff --git a/timm/models/nest.py b/timm/models/nest.py index 8692a2b1..c9c6258c 100644 --- a/timm/models/nest.py +++ b/timm/models/nest.py @@ -25,12 +25,14 @@ import torch.nn.functional as F from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_function -from .helpers import build_model_with_cfg, named_apply, checkpoint_seq -from .layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_ -from .layers import _assert -from .layers import create_conv2d, create_pool2d, to_ntuple -from .registry import register_model +from timm.layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_, _assert +from timm.layers import create_conv2d, create_pool2d, to_ntuple +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._manipulate import checkpoint_seq, named_apply +from ._registry import register_model + +__all__ = ['Nest'] # model_registry will add each entrypoint fn to this _logger = logging.getLogger(__name__) diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index 3a45410b..48f91b35 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -16,21 +16,23 @@ Status: Hacked together by / copyright Ross 
Wightman, 2021. """ -import math -from dataclasses import dataclass, field from collections import OrderedDict -from typing import Tuple, Optional +from dataclasses import dataclass from functools import partial +from typing import Tuple, Optional import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_module -from .helpers import build_model_with_cfg, checkpoint_seq -from .registry import register_model -from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\ +from timm.layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame, \ get_act_layer, get_act_fn, get_attn, make_divisible +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_module +from ._manipulate import checkpoint_seq +from ._registry import register_model + +__all__ = ['NormFreeNet', 'NfCfg'] # model_registry will add each entrypoint fn to this def _dcfg(url='', **kwargs): diff --git a/timm/models/pit.py b/timm/models/pit.py index 0f571319..4f40e5e0 100644 --- a/timm/models/pit.py +++ b/timm/models/pit.py @@ -13,7 +13,6 @@ Modifications for timm by / Copyright 2020 Ross Wightman import math import re -from copy import deepcopy from functools import partial from typing import Tuple @@ -21,12 +20,15 @@ import torch from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import trunc_normal_, to_2tuple -from .registry import register_model +from timm.layers import trunc_normal_, to_2tuple +from ._builder import build_model_with_cfg +from ._registry import register_model from .vision_transformer import Block +__all__ = ['PoolingVisionTransformer'] # model_registry will add each entrypoint fn to this + + def _cfg(url='', **kwargs): return { 'url': url, diff --git a/timm/models/pnasnet.py b/timm/models/pnasnet.py index 81067845..7291c8fb 100644 --- a/timm/models/pnasnet.py +++ b/timm/models/pnasnet.py @@ -12,9 +12,9 @@ import torch import torch.nn as nn import torch.nn.functional as F -from .helpers import build_model_with_cfg -from .layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier -from .registry import register_model +from timm.layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['PNASNet5Large'] diff --git a/timm/models/poolformer.py b/timm/models/poolformer.py index 09359bc8..b4d2d18f 100644 --- a/timm/models/poolformer.py +++ b/timm/models/poolformer.py @@ -19,15 +19,15 @@ Modifications and additions for timm by / Copyright 2022, Ross Wightman # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
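Several hunks above (levit, mobilenetv3, nfnet) now import checkpoint_seq from ._manipulate, the new home of the helper relocated earlier in this patch. A small self-contained sketch of the usual wiring; the module and flag names here are illustrative, not taken from any one model file.

import torch
import torch.nn as nn
from timm.models._manipulate import checkpoint_seq   # new location after this restructure

class TinyNet(nn.Module):
    def __init__(self, grad_checkpointing=True):
        super().__init__()
        self.grad_checkpointing = grad_checkpointing
        self.blocks = nn.Sequential(*[nn.Sequential(nn.Linear(32, 32), nn.GELU()) for _ in range(8)])

    def forward(self, x):
        if self.grad_checkpointing and not torch.jit.is_scripting():
            # every=2 checkpoints the blocks in segments of two; activations inside a
            # segment are recomputed during backward instead of being stored
            x = checkpoint_seq(self.blocks, x, every=2)
        else:
            x = self.blocks(x)
        return x

net = TinyNet()
out = net(torch.randn(4, 32, requires_grad=True))   # input needs requires_grad for checkpointed grads
out.sum().backward()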
-import os -import copy import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import DropPath, trunc_normal_, to_2tuple, ConvMlp, GroupNorm1 -from .registry import register_model +from timm.layers import DropPath, trunc_normal_, to_2tuple, ConvMlp, GroupNorm1 +from ._builder import build_model_with_cfg +from ._registry import register_model + +__all__ = ['PoolFormer'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/pvt_v2.py b/timm/models/pvt_v2.py index dd3cf690..696a2506 100644 --- a/timm/models/pvt_v2.py +++ b/timm/models/pvt_v2.py @@ -24,9 +24,9 @@ import torch.nn as nn import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import DropPath, to_2tuple, to_ntuple, trunc_normal_ -from .registry import register_model +from timm.layers import DropPath, to_2tuple, to_ntuple, trunc_normal_ +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['PyramidVisionTransformerV2'] diff --git a/timm/models/regnet.py b/timm/models/regnet.py index 0ad7c826..e1cc821b 100644 --- a/timm/models/regnet.py +++ b/timm/models/regnet.py @@ -23,10 +23,13 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, named_apply, checkpoint_seq -from .layers import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, GroupNormAct -from .layers import get_act_layer, get_norm_act_layer, create_conv2d -from .registry import register_model +from timm.layers import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, GroupNormAct +from timm.layers import get_act_layer, get_norm_act_layer, create_conv2d +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq, named_apply +from ._registry import register_model + +__all__ = ['RegNet', 'RegNetCfg'] # model_registry will add each entrypoint fn to this @dataclass diff --git a/timm/models/res2net.py b/timm/models/res2net.py index 6c2fd1bf..4724df2a 100644 --- a/timm/models/res2net.py +++ b/timm/models/res2net.py @@ -8,8 +8,8 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .registry import register_model +from ._builder import build_model_with_cfg +from ._registry import register_model from .resnet import ResNet __all__ = [] diff --git a/timm/models/resnest.py b/timm/models/resnest.py index 735b91a2..3b001c7b 100644 --- a/timm/models/resnest.py +++ b/timm/models/resnest.py @@ -6,13 +6,12 @@ Adapted from original PyTorch impl w/ weights at https://github.com/zhanghang198 Modified for torchscript compat, and consistency with timm by Ross Wightman """ -import torch from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import SplitAttn -from .registry import register_model +from timm.layers import SplitAttn +from ._builder import build_model_with_cfg +from ._registry import register_model from .resnet import ResNet diff --git a/timm/models/resnet.py b/timm/models/resnet.py index d0d98894..50849017 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -15,9 +15,11 @@ import torch.nn as nn import torch.nn.functional as F 
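The recurring "__all__ = [...]  # model_registry will add each entrypoint fn to this" lines above narrow wildcard exports to the architecture classes, while @register_model keeps every entrypoint discoverable through the registry. A quick sketch of querying that registry; the poolformer names are just one example among the models touched above.

import timm

print(timm.list_models('poolformer*')[:3])   # entrypoints registered via @register_model
m = timm.create_model('poolformer_s12', pretrained=False)
print(type(m).__name__)                      # PoolFormer, the class exported via __all__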
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, create_attn, get_attn, create_classifier -from .registry import register_model +from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, create_attn, get_attn, \ + create_classifier +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model, model_entrypoint __all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this @@ -675,6 +677,11 @@ class ResNet(nn.Module): self.init_weights(zero_init_last=zero_init_last) + @staticmethod + def from_pretrained(model_name: str, load_weights=True, **kwargs) -> 'ResNet': + entry_fn = model_entrypoint(model_name, 'resnet') + return entry_fn(pretrained=not load_weights, **kwargs) + @torch.jit.ignore def init_weights(self, zero_init_last=True): for n, m in self.named_modules(): @@ -822,7 +829,7 @@ def resnet50(pretrained=False, **kwargs): @register_model -def resnet50d(pretrained=False, **kwargs): +def resnet50d(pretrained=False, **kwargs) -> ResNet: """Constructs a ResNet-50-D model. """ model_args = dict( diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index b21ef7f5..f8c4298b 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -30,16 +30,19 @@ Original copyright of Google code below, modifications by Ross Wightman, Copyrig # limitations under the License. from collections import OrderedDict # pylint: disable=g-importing-member +from functools import partial import torch import torch.nn as nn -from functools import partial from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg, named_apply, adapt_input_conv, checkpoint_seq -from .registry import register_model -from .layers import GroupNormAct, BatchNormAct2d, EvoNorm2dB0, EvoNorm2dS0, EvoNorm2dS1, FilterResponseNormTlu2d,\ +from timm.layers import GroupNormAct, BatchNormAct2d, EvoNorm2dB0, EvoNorm2dS0, FilterResponseNormTlu2d, \ ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq, named_apply, adapt_input_conv +from ._registry import register_model + +__all__ = ['ResNetV2'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py index 33e97222..51e8cdc2 100644 --- a/timm/models/rexnet.py +++ b/timm/models/rexnet.py @@ -10,16 +10,20 @@ Changes for timm, feature extraction, and rounded channel variant hacked togethe Copyright 2020 Ross Wightman """ -import torch -import torch.nn as nn from functools import partial from math import ceil +import torch +import torch.nn as nn + from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import ClassifierHead, create_act_layer, ConvNormAct, DropPath, make_divisible, SEModule -from .registry import register_model -from .efficientnet_builder import efficientnet_init_weights +from timm.layers import ClassifierHead, create_act_layer, ConvNormAct, DropPath, make_divisible, SEModule +from ._builder import build_model_with_cfg +from ._efficientnet_builder import efficientnet_init_weights +from ._manipulate import checkpoint_seq +from ._registry import register_model + 
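The resnet.py hunk above also introduces a from_pretrained helper built on model_entrypoint(model_name, 'resnet'), where the second argument looks to restrict lookup to entrypoints registered by the resnet module. A tentative usage sketch; how the load_weights flag maps onto the entrypoint's pretrained argument is exactly as wired in that hunk and is not restated here.

from timm.models.resnet import ResNet

# Builds a ResNet through the registered 'resnet50d' entrypoint; weight loading is
# governed by the load_weights flag as defined in the staticmethod above.
model = ResNet.from_pretrained('resnet50d')
print(type(model).__name__)   # ResNet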
+__all__ = ['ReXNetV1'] # model_registry will add each entrypoint fn to this def _cfg(url=''): diff --git a/timm/models/selecsls.py b/timm/models/selecsls.py index 1a9ac929..4d40c49a 100644 --- a/timm/models/selecsls.py +++ b/timm/models/selecsls.py @@ -16,9 +16,9 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import create_classifier -from .registry import register_model +from timm.layers import create_classifier +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['SelecSLS'] # model_registry will add each entrypoint fn to this diff --git a/timm/models/senet.py b/timm/models/senet.py index a9e23ff1..d36e9854 100644 --- a/timm/models/senet.py +++ b/timm/models/senet.py @@ -19,9 +19,9 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import create_classifier -from .registry import register_model +from timm.layers import create_classifier +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['SENet'] diff --git a/timm/models/sequencer.py b/timm/models/sequencer.py index b1ae92a4..f3f758b9 100644 --- a/timm/models/sequencer.py +++ b/timm/models/sequencer.py @@ -6,7 +6,6 @@ Paper: `Sequencer: Deep LSTM for Image Classification` - https://arxiv.org/pdf/2 # Copyright (c) 2022. Yuki Tatsunami # Licensed under the Apache License, Version 2.0 (the "License"); - import math from functools import partial from typing import Tuple @@ -15,9 +14,12 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT -from .helpers import build_model_with_cfg, named_apply -from .layers import lecun_normal_, DropPath, Mlp, PatchEmbed as TimmPatchEmbed -from .registry import register_model +from timm.layers import lecun_normal_, DropPath, Mlp, PatchEmbed as TimmPatchEmbed +from ._builder import build_model_with_cfg +from ._manipulate import named_apply +from ._registry import register_model + +__all__ = ['Sequencer2D'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/sknet.py b/timm/models/sknet.py index fb9f063a..5a29b9a4 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -13,9 +13,9 @@ import math from torch import nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import SelectiveKernel, ConvNormAct, ConvNormActAa, create_attn -from .registry import register_model +from timm.layers import SelectiveKernel, ConvNormAct, create_attn +from ._builder import build_model_with_cfg +from ._registry import register_model from .resnet import ResNet diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index f2305fb2..5df06d4d 100644 --- a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -17,19 +17,20 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W # -------------------------------------------------------- import logging import math -from functools import partial from typing import Optional import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_function -from .helpers import build_model_with_cfg, 
named_apply, checkpoint_seq -from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert -from .registry import register_model +from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._manipulate import checkpoint_seq, named_apply +from ._registry import register_model from .vision_transformer import checkpoint_filter_fn, get_init_weights_vit +__all__ = ['SwinTransformer'] # model_registry will add each entrypoint fn to this _logger = logging.getLogger(__name__) diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py index 0c9db3dd..efaaa9e9 100644 --- a/timm/models/swin_transformer_v2.py +++ b/timm/models/swin_transformer_v2.py @@ -21,10 +21,12 @@ import torch.nn.functional as F import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_function -from .helpers import build_model_with_cfg, named_apply -from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert -from .registry import register_model +from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._registry import register_model + +__all__ = ['SwinTransformerV2'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index d143c14c..cf10b39c 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -29,7 +29,6 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W # -------------------------------------------------------- import logging import math -from copy import deepcopy from typing import Tuple, Optional, List, Union, Any, Type import torch @@ -38,11 +37,13 @@ import torch.nn.functional as F import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_function -from .helpers import build_model_with_cfg, named_apply -from .layers import DropPath, Mlp, to_2tuple, _assert -from .registry import register_model +from timm.layers import DropPath, Mlp, to_2tuple, _assert +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._manipulate import named_apply +from ._registry import register_model +__all__ = ['SwinTransformerV2Cr'] # model_registry will add each entrypoint fn to this _logger = logging.getLogger(__name__) diff --git a/timm/models/tnt.py b/timm/models/tnt.py index 5b72b196..50088baf 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -7,17 +7,18 @@ The official mindspore code is released and available at https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT """ import math + import torch import torch.nn as nn from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.models.helpers import build_model_with_cfg -from timm.models.layers import Mlp, DropPath, trunc_normal_ -from timm.models.layers.helpers import to_2tuple -from timm.models.layers import _assert -from timm.models.registry import register_model -from timm.models.vision_transformer import 
resize_pos_embed +from timm.layers import Mlp, DropPath, trunc_normal_, _assert, to_2tuple +from ._builder import build_model_with_cfg +from ._registry import register_model +from .vision_transformer import resize_pos_embed + +__all__ = ['TNT'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py index 2469acd2..83cb0576 100644 --- a/timm/models/tresnet.py +++ b/timm/models/tresnet.py @@ -10,11 +10,11 @@ from collections import OrderedDict import torch import torch.nn as nn -from .helpers import build_model_with_cfg -from .layers import SpaceToDepthModule, BlurPool2d, InplaceAbn, ClassifierHead, SEModule -from .registry import register_model +from timm.layers import SpaceToDepthModule, BlurPool2d, InplaceAbn, ClassifierHead, SEModule +from ._builder import build_model_with_cfg +from ._registry import register_model -__all__ = ['tresnet_m', 'tresnet_l', 'tresnet_xl'] +__all__ = ['TResNet'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/twins.py b/timm/models/twins.py index 0626db37..41944c36 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -12,20 +12,21 @@ Code/weights from https://github.com/Meituan-AutoML/Twins, original copyright/li # Written by Xinjie Li, Xiangxiang Chu # -------------------------------------------------------- import math -from copy import deepcopy -from typing import Optional, Tuple +from functools import partial +from typing import Tuple import torch import torch.nn as nn import torch.nn.functional as F -from functools import partial from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .layers import Mlp, DropPath, to_2tuple, trunc_normal_ -from .fx_features import register_notrace_module -from .registry import register_model +from timm.layers import Mlp, DropPath, to_2tuple, trunc_normal_ +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_module +from ._registry import register_model from .vision_transformer import Attention -from .helpers import build_model_with_cfg + +__all__ = ['Twins'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/vgg.py b/timm/models/vgg.py index caf96517..abe9f8d5 100644 --- a/timm/models/vgg.py +++ b/timm/models/vgg.py @@ -5,21 +5,19 @@ timm functionality. 
Copyright 2021 Ross Wightman """ +from typing import Union, List, Dict, Any, cast + import torch import torch.nn as nn import torch.nn.functional as F -from typing import Union, List, Dict, Any, cast from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, checkpoint_seq -from .fx_features import register_notrace_module -from .layers import ClassifierHead -from .registry import register_model - -__all__ = [ - 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', - 'vgg19_bn', 'vgg19', -] +from timm.layers import ClassifierHead +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_module +from ._registry import register_model + +__all__ = ['VGG'] def _cfg(url='', **kwargs): diff --git a/timm/models/visformer.py b/timm/models/visformer.py index 254a0748..e15ae4a5 100644 --- a/timm/models/visformer.py +++ b/timm/models/visformer.py @@ -6,17 +6,15 @@ From original at https://github.com/danczs/Visformer Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman """ -from copy import deepcopy import torch import torch.nn as nn -import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d, create_classifier -from .registry import register_model - +from timm.layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d, create_classifier +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model __all__ = ['Visformer'] diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 4effbed6..3c2ebc29 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -19,10 +19,10 @@ for some einops/einsum fun Hacked together by / Copyright 2020, Ross Wightman """ -import math import logging -from functools import partial +import math from collections import OrderedDict +from functools import partial from typing import Optional import torch @@ -30,12 +30,17 @@ import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD,\ +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \ OPENAI_CLIP_MEAN, OPENAI_CLIP_STD -from .helpers import build_model_with_cfg, named_apply, adapt_input_conv, checkpoint_seq -from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ -from .pretrained import generate_default_cfgs -from .registry import register_model +from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ +from ._builder import build_model_with_cfg +from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv +from ._pretrained import generate_default_cfgs +from ._registry import register_model + + +__all__ = ['VisionTransformer'] # model_registry will add each entrypoint fn to this + _logger = logging.getLogger(__name__) diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index 5e5113d7..cfdd0a0e 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -13,19 +13,18 @@ They were moved here to keep file sizes sane. 
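vision_transformer.py and the swin variants above now pull named_apply from ._manipulate; the helper (relocated earlier in this patch) walks the module tree and calls fn(module=..., name=...), which the models use for name-aware weight init. A compact sketch with an illustrative init function.

import torch.nn as nn
from timm.models._manipulate import named_apply   # new location after this restructure

def init_fn(module, name=''):
    # Name-aware init: only Linear layers, and the dotted name tells us where we are.
    if isinstance(module, nn.Linear):
        nn.init.zeros_(module.bias)
        print(f'initialized {name}')

net = nn.Sequential(nn.Linear(8, 8), nn.Sequential(nn.Linear(8, 4)))
named_apply(init_fn, net)   # prints 'initialized 0' then 'initialized 1.0'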
Hacked together by / Copyright 2020, Ross Wightman """ -from copy import deepcopy from functools import partial import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .layers import StdConv2dSame, StdConv2d, to_2tuple -from .pretrained import generate_default_cfgs +from timm.layers import StdConv2dSame, StdConv2d, to_2tuple +from ._pretrained import generate_default_cfgs +from ._registry import register_model from .resnet import resnet26d, resnet50d from .resnetv2 import ResNetV2, create_resnetv2_stem -from .registry import register_model -from timm.models.vision_transformer import _create_vision_transformer +from .vision_transformer import _create_vision_transformer def _cfg(url='', **kwargs): diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py index 52b3ce45..1a7c2f40 100644 --- a/timm/models/vision_transformer_relpos.py +++ b/timm/models/vision_transformer_relpos.py @@ -4,11 +4,9 @@ NOTE: these models are experimental / WIP, expect changes Hacked together by / Copyright 2022, Ross Wightman """ -import math import logging +import math from functools import partial -from collections import OrderedDict -from dataclasses import dataclass from typing import Optional, Tuple import torch @@ -16,10 +14,12 @@ import torch.nn as nn import torch.nn.functional as F from torch.utils.checkpoint import checkpoint -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply -from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, to_2tuple -from .registry import register_model +from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_ +from ._builder import build_model_with_cfg +from ._registry import register_model + +__all__ = ['VisionTransformerRelPos'] # model_registry will add each entrypoint fn to this _logger = logging.getLogger(__name__) diff --git a/timm/models/volo.py b/timm/models/volo.py index 735453c8..1117995a 100644 --- a/timm/models/volo.py +++ b/timm/models/volo.py @@ -20,17 +20,19 @@ Modifications and additions for timm by / Copyright 2022, Ross Wightman # See the License for the specific language governing permissions and # limitations under the License. 
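The group_with_matcher / group_parameters helpers relocated at the start of this section turn a dict of raw regex strings into ordered layer groups, with unmatched names (typically the head) falling into the last group; the optim_factory hunk further below switches its import of group_parameters to timm.models. A minimal sketch; the matcher spec here is illustrative, not one any model defines.

import timm
from timm.models import group_parameters   # re-exported path, as used by optim_factory below

model = timm.create_model('resnet50', pretrained=False)
group_matcher = dict(
    stem=r'^conv1|^bn1',        # ordinal 0
    blocks=r'^layer(\d+)',      # capture group yields per-stage sub-ordinals
)
groups = group_parameters(model, group_matcher)
for layer_id, names in sorted(groups.items()):
    print(layer_id, len(names))   # head params (fc.*) land in the final, unmatched group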
import math -import numpy as np +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.models.layers import DropPath, Mlp, to_2tuple, to_ntuple, trunc_normal_ -from timm.models.registry import register_model -from timm.models.helpers import build_model_with_cfg +from timm.layers import DropPath, Mlp, to_2tuple, to_ntuple, trunc_normal_ +from ._builder import build_model_with_cfg +from ._registry import register_model + +__all__ = ['VOLO'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/models/vovnet.py b/timm/models/vovnet.py index 39d37195..bf0e4f89 100644 --- a/timm/models/vovnet.py +++ b/timm/models/vovnet.py @@ -15,13 +15,15 @@ from typing import List import torch import torch.nn as nn -import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .registry import register_model -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import ConvNormAct, SeparableConvNormAct, BatchNormAct2d, ClassifierHead, DropPath,\ +from timm.layers import ConvNormAct, SeparableConvNormAct, BatchNormAct2d, ClassifierHead, DropPath, \ create_attn, create_norm_act_layer, get_norm_act_layer +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model + +__all__ = ['VovNet'] # model_registry will add each entrypoint fn to this # model cfgs adapted from https://github.com/youngwanLEE/vovnet-detectron2 & diff --git a/timm/models/xception.py b/timm/models/xception.py index 99d02c46..99e74b46 100644 --- a/timm/models/xception.py +++ b/timm/models/xception.py @@ -25,9 +25,9 @@ import torch.jit import torch.nn as nn import torch.nn.functional as F -from .helpers import build_model_with_cfg -from .layers import create_classifier -from .registry import register_model +from timm.layers import create_classifier +from ._builder import build_model_with_cfg +from ._registry import register_model __all__ = ['Xception'] diff --git a/timm/models/xception_aligned.py b/timm/models/xception_aligned.py index 6bbce5e6..e3348e64 100644 --- a/timm/models/xception_aligned.py +++ b/timm/models/xception_aligned.py @@ -11,10 +11,11 @@ import torch import torch.nn as nn from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import ClassifierHead, ConvNormAct, create_conv2d, get_norm_act_layer -from .layers.helpers import to_3tuple -from .registry import register_model +from timm.layers import ClassifierHead, ConvNormAct, create_conv2d, get_norm_act_layer +from timm.layers.helpers import to_3tuple +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model __all__ = ['XceptionAligned'] diff --git a/timm/models/xcit.py b/timm/models/xcit.py index 6802fc84..57c11183 100644 --- a/timm/models/xcit.py +++ b/timm/models/xcit.py @@ -19,12 +19,14 @@ import torch.nn as nn from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .vision_transformer import _cfg, Mlp -from .registry import register_model -from .layers import DropPath, trunc_normal_, to_2tuple +from timm.layers import DropPath, trunc_normal_, to_2tuple +from ._builder import build_model_with_cfg +from 
._features_fx import register_notrace_module +from ._registry import register_model from .cait import ClassAttn -from .fx_features import register_notrace_module +from .vision_transformer import Mlp + +__all__ = ['XCiT'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py index 02f0e250..8613a62c 100644 --- a/timm/optim/optim_factory.py +++ b/timm/optim/optim_factory.py @@ -9,7 +9,7 @@ import torch import torch.nn as nn import torch.optim as optim -from timm.models.helpers import group_parameters +from timm.models import group_parameters from .adabelief import AdaBelief from .adafactor import Adafactor diff --git a/timm/version.py b/timm/version.py index 0f19999f..0716d38a 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.8.0dev0' +__version__ = '0.8.1dev0' diff --git a/train.py b/train.py index d40ff04b..1276840d 100755 --- a/train.py +++ b/train.py @@ -31,10 +31,9 @@ from torch.nn.parallel import DistributedDataParallel as NativeDDP from timm import utils from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset -from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, \ - LabelSmoothingCrossEntropy -from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, \ - convert_splitbn_model, convert_sync_batchnorm, model_parameters, set_fast_norm +from timm.layers import convert_splitbn_model, convert_sync_batchnorm, set_fast_norm +from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, LabelSmoothingCrossEntropy +from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, model_parameters from timm.optim import create_optimizer_v2, optimizer_kwargs from timm.scheduler import create_scheduler_v2, scheduler_kwargs from timm.utils import ApexScaler, NativeScaler @@ -82,7 +81,7 @@ parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') # Dataset parameters group = parser.add_argument_group('Dataset parameters') -# Keep this argument outside of the dataset group because it is positional. +# Keep this argument outside the dataset group because it is positional. parser.add_argument('data', nargs='?', metavar='DIR', const=None, help='path to dataset (positional is *deprecated*, use --data-dir)') parser.add_argument('--data-dir', metavar='DIR', diff --git a/validate.py b/validate.py index 6b8222b9..3bbf07cf 100755 --- a/validate.py +++ b/validate.py @@ -8,22 +8,24 @@ canonical PyTorch, standard Python style, and good performance. 
Repurpose as you Hacked together by Ross Wightman (https://github.com/rwightman) """ import argparse -import os import csv import glob import json -import time import logging -import torch -import torch.nn as nn -import torch.nn.parallel +import os +import time from collections import OrderedDict from contextlib import suppress from functools import partial -from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models, set_fast_norm +import torch +import torch.nn as nn +import torch.nn.parallel + from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet -from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser,\ +from timm.layers import apply_test_time_pool, set_fast_norm +from timm.models import create_model, load_checkpoint, is_model, list_models +from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser, \ decay_batch_step, check_batch_size_retry try: From 7c4ed4d5a43f46084cc9b6f20a5edb8839bbeb14 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 8 Dec 2022 16:20:49 -0800 Subject: [PATCH 2/9] Add EVA-large models --- README.md | 11 +++++++++ timm/models/vision_transformer.py | 37 +++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/README.md b/README.md index 994775f1..331ea7a8 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,17 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before ## What's New +# Dec 8, 2022 +* Add 'EVA l' to `vision_transformer.py`, MAE style ViT-L/14 MIM pretrain w/ EVA-CLIP targets, FT on ImageNet-1k (w/ ImageNet-22k intermediate for some) + * original source: https://github.com/baaivision/EVA + +| model | top1 | param_count | gmac | macts | hub | +|:------------------------------------------|-----:|------------:|------:|------:|:----------------------------------------| +| eva_large_patch14_336.in22k_ft_in22k_in1k | 89.2 | 304.5 | 191.1 | 270.2 | [link](https://huggingface.co/BAAI/EVA) | +| eva_large_patch14_336.in22k_ft_in1k | 88.7 | 304.5 | 191.1 | 270.2 | [link](https://huggingface.co/BAAI/EVA) | +| eva_large_patch14_196.in22k_ft_in22k_in1k | 88.6 | 304.1 | 61.6 | 63.5 | [link](https://huggingface.co/BAAI/EVA) | +| eva_large_patch14_196.in22k_ft_in1k | 87.9 | 304.1 | 61.6 | 63.5 | [link](https://huggingface.co/BAAI/EVA) | + # Dec 6, 2022 * Add 'EVA g', BEiT style ViT-g/14 model weights w/ both MIM pretrain and CLIP pretrain to `beit.py`. 
* original source: https://github.com/baaivision/EVA diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 4effbed6..820dc656 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -933,6 +933,25 @@ default_cfgs = generate_default_cfgs({ 'vit_small_patch16_36x1_224': _cfg(url=''), 'vit_small_patch16_18x2_224': _cfg(url=''), 'vit_base_patch16_18x2_224': _cfg(url=''), + + # EVA fine-tuned weights from MAE style MIM - EVA-CLIP target pretrain + # https://github.com/baaivision/EVA/blob/7ecf2c0a370d97967e86d047d7af9188f78d2df3/eva/README.md#eva-l-learning-better-mim-representations-from-eva-clip + 'eva_large_patch14_196.in22k_ft_in22k_in1k': _cfg( + hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_21k_to_1k_ft_88p6.pt', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 196, 196), crop_pct=1.0), + 'eva_large_patch14_336.in22k_ft_in22k_in1k': _cfg( + hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_21k_to_1k_ft_89p2.pt', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'), + 'eva_large_patch14_196.in22k_ft_in1k': _cfg( + hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_1k_ft_88p0.pt', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 196, 196), crop_pct=1.0), + 'eva_large_patch14_336.in22k_ft_in1k': _cfg( + hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_1k_ft_88p65.pt', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'), }) @@ -1354,3 +1373,21 @@ def vit_base_patch16_18x2_224(pretrained=False, **kwargs): patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock, **kwargs) model = _create_vision_transformer('vit_base_patch16_18x2_224', pretrained=pretrained, **model_kwargs) return model + + +@register_model +def eva_large_patch14_196(pretrained=False, **kwargs): + """ EVA-large model https://arxiv.org/abs/2211.07636 /via MAE MIM pretrain""" + model_kwargs = dict( + patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg', **kwargs) + model = _create_vision_transformer('eva_large_patch14_196', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def eva_large_patch14_336(pretrained=False, **kwargs): + """ EVA-large model https://arxiv.org/abs/2211.07636 via MAE MIM pretrain""" + model_kwargs = dict( + patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg', **kwargs) + model = _create_vision_transformer('eva_large_patch14_336', pretrained=pretrained, **model_kwargs) + return model From 3d6bc42aa15b67883e3bf0f92df92fc7b74030b1 Mon Sep 17 00:00:00 2001 From: Lorenzo Baraldi Date: Fri, 9 Dec 2022 12:03:23 +0100 Subject: [PATCH 3/9] Put validation loss under amp_autocast Secured the loss evaluation under the amp, avoiding function to operate on float16 --- train.py | 16 ++++++++-------- validate.py | 6 +++--- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/train.py b/train.py index d40ff04b..b85eb6b0 100755 --- a/train.py +++ b/train.py @@ -970,16 +970,16 @@ def validate( with amp_autocast(): output = model(input) - if isinstance(output, (tuple, list)): - output = output[0] + if isinstance(output, (tuple, list)): + output = output[0] - # augmentation reduction - reduce_factor = args.tta - if reduce_factor > 1: - output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2) - target = target[0:target.size(0):reduce_factor] + # augmentation reduction + 
reduce_factor = args.tta + if reduce_factor > 1: + output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2) + target = target[0:target.size(0):reduce_factor] - loss = loss_fn(output, target) + loss = loss_fn(output, target) acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) if args.distributed: diff --git a/validate.py b/validate.py index 6b8222b9..872f27b0 100755 --- a/validate.py +++ b/validate.py @@ -294,9 +294,9 @@ def validate(args): with amp_autocast(): output = model(input) - if valid_labels is not None: - output = output[:, valid_labels] - loss = criterion(output, target) + if valid_labels is not None: + output = output[:, valid_labels] + loss = criterion(output, target) if real_labels is not None: real_labels.add_result(output) From 9e47d8ad5942a3e60ab12978c1aca5068c201929 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 9 Dec 2022 11:13:37 -0800 Subject: [PATCH 4/9] Update README.md --- README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/README.md b/README.md index 331ea7a8..a63c2b14 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,13 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before ## What's New +# Survey: Feedback Appreciated + +For a few months now, `timm` has been part of the Hugging Face ecosystem. Yearly, we survey users of our tools to see what we could do better, what we need to continue doing, or what we need to stop doing. + +If you have a couple of minutes and want to participate in shaping the future of the ecosystem, please share your thoughts: +[**hf.co/oss-survey**](https://hf.co/oss-survey) + # Dec 8, 2022 * Add 'EVA l' to `vision_transformer.py`, MAE style ViT-L/14 MIM pretrain w/ EVA-CLIP targets, FT on ImageNet-1k (w/ ImageNet-22k intermediate for some) * original source: https://github.com/baaivision/EVA From 1733177c75a23d3f8b34ffe4c8c9316440bef323 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 9 Dec 2022 11:14:35 -0800 Subject: [PATCH 5/9] Update README.md --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index a63c2b14..130b604c 100644 --- a/README.md +++ b/README.md @@ -21,14 +21,14 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before ## What's New -# Survey: Feedback Appreciated +### Survey: Feedback Appreciated For a few months now, `timm` has been part of the Hugging Face ecosystem. Yearly, we survey users of our tools to see what we could do better, what we need to continue doing, or what we need to stop doing. If you have a couple of minutes and want to participate in shaping the future of the ecosystem, please share your thoughts: [**hf.co/oss-survey**](https://hf.co/oss-survey) -# Dec 8, 2022 +### Dec 8, 2022 * Add 'EVA l' to `vision_transformer.py`, MAE style ViT-L/14 MIM pretrain w/ EVA-CLIP targets, FT on ImageNet-1k (w/ ImageNet-22k intermediate for some) * original source: https://github.com/baaivision/EVA @@ -39,7 +39,7 @@ If you have a couple of minutes and want to participate in shaping the future of | eva_large_patch14_196.in22k_ft_in22k_in1k | 88.6 | 304.1 | 61.6 | 63.5 | [link](https://huggingface.co/BAAI/EVA) | | eva_large_patch14_196.in22k_ft_in1k | 87.9 | 304.1 | 61.6 | 63.5 | [link](https://huggingface.co/BAAI/EVA) | -# Dec 6, 2022 +### Dec 6, 2022 * Add 'EVA g', BEiT style ViT-g/14 model weights w/ both MIM pretrain and CLIP pretrain to `beit.py`. 
* original source: https://github.com/baaivision/EVA * paper: https://arxiv.org/abs/2211.07636 @@ -51,7 +51,7 @@ If you have a couple of minutes and want to participate in shaping the future of | eva_giant_patch14_336.clip_ft_in1k | 89.4 | 1013 | 620.6 | 550.7 | [link](https://huggingface.co/BAAI/EVA) | | eva_giant_patch14_224.clip_ft_in1k | 89.1 | 1012.6 | 267.2 | 192.6 | [link](https://huggingface.co/BAAI/EVA) | -# Dec 5, 2022 +### Dec 5, 2022 * Pre-release (`0.8.0dev0`) of multi-weight support (`model_arch.pretrained_tag`). Install with `pip install --pre timm` * vision_transformer, maxvit, convnext are the first three model impl w/ support From 0fe90449e5e07e91a78ea847b76eda6010b55283 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 9 Dec 2022 11:21:46 -0800 Subject: [PATCH 6/9] Update README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 130b604c..bb6485c0 100644 --- a/README.md +++ b/README.md @@ -21,12 +21,12 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before ## What's New -### Survey: Feedback Appreciated +### 🤗 Survey: Feedback Appreciated 🤗 For a few months now, `timm` has been part of the Hugging Face ecosystem. Yearly, we survey users of our tools to see what we could do better, what we need to continue doing, or what we need to stop doing. If you have a couple of minutes and want to participate in shaping the future of the ecosystem, please share your thoughts: -[**hf.co/oss-survey**](https://hf.co/oss-survey) +[**hf.co/oss-survey**](https://hf.co/oss-survey) 🙏 ### Dec 8, 2022 * Add 'EVA l' to `vision_transformer.py`, MAE style ViT-L/14 MIM pretrain w/ EVA-CLIP targets, FT on ImageNet-1k (w/ ImageNet-22k intermediate for some) From cda39b35bd7ac4a8053f422802ba65f88dbb6e3c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 9 Dec 2022 14:39:45 -0800 Subject: [PATCH 7/9] Add a deprecation phase to module re-org --- benchmark.py | 3 ++- timm/models/_builder.py | 4 ++++ timm/models/_factory.py | 3 +++ timm/models/_features.py | 3 +++ timm/models/_features_fx.py | 4 ++++ timm/models/_helpers.py | 2 ++ timm/models/_hub.py | 3 +++ timm/models/_manipulate.py | 3 +++ timm/models/_pretrained.py | 3 +++ timm/models/_prune.py | 2 ++ timm/models/_registry.py | 2 +- timm/models/factory.py | 4 ++++ timm/models/features.py | 4 ++++ timm/models/fx_features.py | 4 ++++ timm/models/helpers.py | 7 +++++++ timm/models/hub.py | 4 ++++ timm/models/layers/__init__.py | 3 +++ timm/models/registry.py | 4 ++++ 18 files changed, 60 insertions(+), 2 deletions(-) create mode 100644 timm/models/factory.py create mode 100644 timm/models/features.py create mode 100644 timm/models/fx_features.py create mode 100644 timm/models/helpers.py create mode 100644 timm/models/hub.py create mode 100644 timm/models/registry.py diff --git a/benchmark.py b/benchmark.py index 04557a7d..95e2cb5a 100755 --- a/benchmark.py +++ b/benchmark.py @@ -19,7 +19,8 @@ import torch.nn as nn import torch.nn.parallel from timm.data import resolve_data_config -from timm.models import create_model, is_model, list_models, set_fast_norm +from timm.layers import set_fast_norm +from timm.models import create_model, is_model, list_models from timm.optim import create_optimizer_v2 from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry diff --git a/timm/models/_builder.py b/timm/models/_builder.py index c99c85f6..f634650e 100644 --- a/timm/models/_builder.py +++ b/timm/models/_builder.py @@ 
-23,6 +23,10 @@ _DOWNLOAD_PROGRESS = False _CHECK_HASH = False +__all__ = ['set_pretrained_download_progress', 'set_pretrained_check_hash', 'load_custom_pretrained', 'load_pretrained', + 'pretrained_cfg_for_features', 'resolve_pretrained_cfg', 'build_model_with_cfg'] + + def _resolve_pretrained_source(pretrained_cfg): cfg_source = pretrained_cfg.get('source', '') pretrained_url = pretrained_cfg.get('url', None) diff --git a/timm/models/_factory.py b/timm/models/_factory.py index 2b050ad6..a8092419 100644 --- a/timm/models/_factory.py +++ b/timm/models/_factory.py @@ -9,6 +9,9 @@ from ._hub import load_model_config_from_hf from ._registry import is_model, model_entrypoint +__all__ = ['parse_model_name', 'safe_model_name', 'create_model'] + + def parse_model_name(model_name): if model_name.startswith('hf_hub'): # NOTE for backwards compat, deprecate hf_hub use diff --git a/timm/models/_features.py b/timm/models/_features.py index 0bc46419..59b080cd 100644 --- a/timm/models/_features.py +++ b/timm/models/_features.py @@ -17,6 +17,9 @@ import torch import torch.nn as nn +__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet'] + + class FeatureInfo: def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]): diff --git a/timm/models/_features_fx.py b/timm/models/_features_fx.py index 2d4a33c2..10670a1d 100644 --- a/timm/models/_features_fx.py +++ b/timm/models/_features_fx.py @@ -35,6 +35,10 @@ except ImportError: pass +__all__ = ['register_notrace_module', 'register_notrace_function', 'create_feature_extractor', + 'FeatureGraphNet', 'GraphExtractNet'] + + def register_notrace_module(module: Type[nn.Module]): """ Any module not under timm.models.layers should get this decorator if we don't want to trace through it. 
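A minimal usage sketch of the deprecation phase this patch adds (not part of the diff; `DropPath` is only an assumed example of a layer re-exported from both the old and new locations):

    # Old import locations keep working via thin shim modules, but emit a
    # DeprecationWarning the first time they are imported.
    import warnings
    warnings.simplefilter('default')  # surface DeprecationWarning outside __main__

    from timm.models.layers import DropPath            # deprecated path, still functional
    from timm.layers import DropPath as NewDropPath    # new, preferred path

    assert DropPath is NewDropPath  # the shim re-exports the same objects rather than copying them

The relocated non-model modules (factory, features, fx_features, helpers, hub, registry) follow the same pattern, with their public symbols remaining importable via `timm.models`.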
diff --git a/timm/models/_helpers.py b/timm/models/_helpers.py index 2856842d..995292aa 100644 --- a/timm/models/_helpers.py +++ b/timm/models/_helpers.py @@ -12,6 +12,8 @@ import timm.models._builder _logger = logging.getLogger(__name__) +__all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_checkpoint', 'resume_checkpoint'] + def clean_state_dict(state_dict): # 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training diff --git a/timm/models/_hub.py b/timm/models/_hub.py index 2a87ae7e..e6b7d558 100644 --- a/timm/models/_hub.py +++ b/timm/models/_hub.py @@ -31,6 +31,9 @@ except ImportError: _logger = logging.getLogger(__name__) +__all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf', + 'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub'] + def get_cache_dir(child_dir=''): """ diff --git a/timm/models/_manipulate.py b/timm/models/_manipulate.py index 82a922a2..192979fc 100644 --- a/timm/models/_manipulate.py +++ b/timm/models/_manipulate.py @@ -9,6 +9,9 @@ import torch from torch import nn as nn from torch.utils.checkpoint import checkpoint +__all__ = ['model_parameters', 'named_apply', 'named_modules', 'named_modules_with_params', 'adapt_input_conv', + 'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq'] + def model_parameters(model, exclude_head=False): if exclude_head: diff --git a/timm/models/_pretrained.py b/timm/models/_pretrained.py index 60f38fd4..c422dab7 100644 --- a/timm/models/_pretrained.py +++ b/timm/models/_pretrained.py @@ -4,6 +4,9 @@ from dataclasses import dataclass, field, replace, asdict from typing import Any, Deque, Dict, Tuple, Optional, Union +__all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg', 'split_model_name_tag', 'generate_default_cfgs'] + + @dataclass class PretrainedCfg: """ diff --git a/timm/models/_prune.py b/timm/models/_prune.py index 0d744e40..4e744dec 100644 --- a/timm/models/_prune.py +++ b/timm/models/_prune.py @@ -5,6 +5,8 @@ from torch import nn as nn from timm.layers import Conv2dSame, BatchNormAct2d, Linear +__all__ = ['extract_layer', 'set_layer', 'adapt_model_from_string', 'adapt_model_from_file'] + def extract_layer(model, layer): layer = layer.split('.') diff --git a/timm/models/_registry.py b/timm/models/_registry.py index 97c8fd59..fc7b3437 100644 --- a/timm/models/_registry.py +++ b/timm/models/_registry.py @@ -12,7 +12,7 @@ from typing import List, Optional, Union, Tuple from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag __all__ = [ - 'list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules', + 'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules', 'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name'] _module_to_models = defaultdict(set) # dict of sets to check membership of model in module diff --git a/timm/models/factory.py b/timm/models/factory.py new file mode 100644 index 00000000..0ae83dc0 --- /dev/null +++ b/timm/models/factory.py @@ -0,0 +1,4 @@ +from ._factory import * + +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning) diff --git a/timm/models/features.py b/timm/models/features.py new file mode 100644 index 00000000..25605d99 --- /dev/null +++ b/timm/models/features.py @@ -0,0 +1,4 @@ +from ._features import * + +import warnings +warnings.warn(f"Importing from 
{__name__} is deprecated, please import via timm.models", DeprecationWarning) diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py new file mode 100644 index 00000000..0ff3a18b --- /dev/null +++ b/timm/models/fx_features.py @@ -0,0 +1,4 @@ +from ._features_fx import * + +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning) diff --git a/timm/models/helpers.py b/timm/models/helpers.py new file mode 100644 index 00000000..6bc82eb8 --- /dev/null +++ b/timm/models/helpers.py @@ -0,0 +1,7 @@ +from ._builder import * +from ._helpers import * +from ._manipulate import * +from ._prune import * + +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning) diff --git a/timm/models/hub.py b/timm/models/hub.py new file mode 100644 index 00000000..074ca025 --- /dev/null +++ b/timm/models/hub.py @@ -0,0 +1,4 @@ +from ._hub import * + +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning) diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 1bfb95d5..97e70563 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -43,3 +43,6 @@ from timm.layers.std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, Scal from timm.layers.test_time_pool import TestTimePoolHead, apply_test_time_pool from timm.layers.trace_utils import _assert, _float_to_int from timm.layers.weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_ + +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", DeprecationWarning) diff --git a/timm/models/registry.py b/timm/models/registry.py new file mode 100644 index 00000000..58e2e1f4 --- /dev/null +++ b/timm/models/registry.py @@ -0,0 +1,4 @@ +from ._registry import * + +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning) From e3b2f5be0afaa6f2dd71be17a689b44b126a3ce9 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 10 Dec 2022 16:25:50 -0800 Subject: [PATCH 8/9] Add 3-Augment support to auto_augment.py, clean up weighted choice handling, and allow adjusting per-op prob via arg string --- timm/data/auto_augment.py | 325 ++++++++++++++++++++++---------- timm/data/transforms_factory.py | 10 +- 2 files changed, 236 insertions(+), 99 deletions(-) diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py index 1b51ccb4..e461f67c 100644 --- a/timm/data/auto_augment.py +++ b/timm/data/auto_augment.py @@ -1,4 +1,4 @@ -""" AutoAugment, RandAugment, and AugMix for PyTorch +""" AutoAugment, RandAugment, AugMix, and 3-Augment for PyTorch This code implements the searched ImageNet policies with various tweaks and improvements and does not include any of the search code. @@ -9,18 +9,24 @@ AA and RA Implementation adapted from: AugMix adapted from: https://github.com/google-research/augmix +3-Augment based on: https://github.com/facebookresearch/deit/blob/main/README_revenge.md + Papers: AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501 Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172 RandAugment: Practical automated data augmentation...
- https://arxiv.org/abs/1909.13719 AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781 + 3-Augment: DeiT III: Revenge of the ViT - https://arxiv.org/abs/2204.07118 Hacked together by / Copyright 2019, Ross Wightman """ import random import math import re -from PIL import Image, ImageOps, ImageEnhance, ImageChops +from functools import partial +from typing import Dict, List, Optional, Union + +from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageFilter import PIL import numpy as np @@ -175,6 +181,24 @@ def sharpness(img, factor, **__): return ImageEnhance.Sharpness(img).enhance(factor) +def gaussian_blur(img, factor, **__): + img = img.filter(ImageFilter.GaussianBlur(radius=factor)) + return img + + +def gaussian_blur_rand(img, factor, **__): + radius_min = 0.1 + radius_max = 2.0 + img = img.filter(ImageFilter.GaussianBlur(radius=random.uniform(radius_min, radius_max * factor))) + return img + + +def desaturate(img, factor, **_): + factor = min(1., max(0., 1. - factor)) + # enhance factor 0 = grayscale, 1.0 = no-change + return ImageEnhance.Color(img).enhance(factor) + + def _randomly_negate(v): """With 50% prob, negate the value""" return -v if random.random() > 0.5 else v @@ -200,6 +224,14 @@ def _enhance_increasing_level_to_arg(level, _hparams): return level, +def _minmax_level_to_arg(level, _hparams, min_val=0., max_val=1.0, clamp=True): + level = (level / _LEVEL_DENOM) + min_val + (max_val - min_val) * level + if clamp: + level = min(min_val, max(max_val, level)) + return level, + + def _shear_level_to_arg(level, _hparams): # range [-0.3, 0.3] level = (level / _LEVEL_DENOM) * 0.3 @@ -246,7 +278,7 @@ def _posterize_original_level_to_arg(level, _hparams): def _solarize_level_to_arg(level, _hparams): # range [0, 256] # intensity/severity of augmentation decreases with level - return int((level / _LEVEL_DENOM) * 256), + return min(256, int((level / _LEVEL_DENOM) * 256)), def _solarize_increasing_level_to_arg(level, _hparams): @@ -257,7 +289,7 @@ def _solarize_increasing_level_to_arg(level, _hparams): def _solarize_add_level_to_arg(level, _hparams): # range [0, 110] - return int((level / _LEVEL_DENOM) * 110), + return min(128, int((level / _LEVEL_DENOM) * 110)), LEVEL_TO_ARG = { @@ -286,6 +318,9 @@ LEVEL_TO_ARG = { 'TranslateY': _translate_abs_level_to_arg, 'TranslateXRel': _translate_rel_level_to_arg, 'TranslateYRel': _translate_rel_level_to_arg, + 'Desaturate': partial(_minmax_level_to_arg, min_val=0.5, max_val=1.0), + 'GaussianBlur': partial(_minmax_level_to_arg, min_val=0.1, max_val=2.0), + 'GaussianBlurRand': _minmax_level_to_arg, } @@ -314,6 +349,9 @@ NAME_TO_OP = { 'TranslateY': translate_y_abs, 'TranslateXRel': translate_x_rel, 'TranslateYRel': translate_y_rel, + 'Desaturate': desaturate, + 'GaussianBlur': gaussian_blur, + 'GaussianBlurRand': gaussian_blur_rand, } @@ -347,6 +385,7 @@ class AugmentOp: if self.magnitude_std > 0: # magnitude randomization enabled if self.magnitude_std == float('inf'): + # inf == uniform sampling magnitude = random.uniform(0, magnitude) elif self.magnitude_std > 0: magnitude = random.gauss(magnitude, self.magnitude_std) @@ -499,6 +538,16 @@ def auto_augment_policy_originalr(hparams): return pc +def auto_augment_policy_3a(hparams): + policy = [ + [('Solarize', 1.0, 5)], # 128 solarize threshold @ 5 magnitude + [('Desaturate', 1.0, 10)], # grayscale at 10 magnitude + [('GaussianBlurRand', 1.0, 10)], + ] + pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] + return pc 
+ + def auto_augment_policy(name='v0', hparams=None): hparams = hparams or _HPARAMS_DEFAULT if name == 'original': @@ -509,6 +558,8 @@ def auto_augment_policy(name='v0', hparams=None): return auto_augment_policy_v0(hparams) elif name == 'v0r': return auto_augment_policy_v0r(hparams) + elif name == '3a': + return auto_augment_policy_3a(hparams) else: assert False, 'Unknown AA policy (%s)' % name @@ -534,19 +585,23 @@ class AutoAugment: return fs -def auto_augment_transform(config_str, hparams): +def auto_augment_transform(config_str: str, hparams: Optional[Dict] = None): """ Create a AutoAugment transform - :param config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by - dashes ('-'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr'). - The remaining sections, not order sepecific determine - 'mstd' - float std deviation of magnitude noise applied - Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5 + Args: + config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by + dashes ('-'). + The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr'). - :param hparams: Other hparams (kwargs) for the AutoAugmentation scheme + The remaining sections: + 'mstd' - float std deviation of magnitude noise applied + Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5 - :return: A PyTorch compatible Transform + hparams: Other hparams (kwargs) for the AutoAugmentation scheme + + Returns: + A PyTorch compatible Transform """ config = config_str.split('-') policy_name = config[0] @@ -605,42 +660,80 @@ _RAND_INCREASING_TRANSFORMS = [ ] +_RAND_3A = [ + 'SolarizeIncreasing', + 'Desaturate', + 'GaussianBlur', +] + + +_RAND_CHOICE_3A = { + 'SolarizeIncreasing': 6, + 'Desaturate': 6, + 'GaussianBlur': 6, + 'Rotate': 3, + 'ShearX': 2, + 'ShearY': 2, + 'PosterizeIncreasing': 1, + 'AutoContrast': 1, + 'ColorIncreasing': 1, + 'SharpnessIncreasing': 1, + 'ContrastIncreasing': 1, + 'BrightnessIncreasing': 1, + 'Equalize': 1, + 'Invert': 1, +} + # These experimental weights are based loosely on the relative improvements mentioned in paper. # They may not result in increased performance, but could likely be tuned to so. 
_RAND_CHOICE_WEIGHTS_0 = { - 'Rotate': 0.3, - 'ShearX': 0.2, - 'ShearY': 0.2, - 'TranslateXRel': 0.1, - 'TranslateYRel': 0.1, - 'Color': .025, - 'Sharpness': 0.025, - 'AutoContrast': 0.025, - 'Solarize': .005, - 'SolarizeAdd': .005, - 'Contrast': .005, - 'Brightness': .005, - 'Equalize': .005, - 'Posterize': 0, - 'Invert': 0, + 'Rotate': 3, + 'ShearX': 2, + 'ShearY': 2, + 'TranslateXRel': 1, + 'TranslateYRel': 1, + 'ColorIncreasing': .25, + 'SharpnessIncreasing': 0.25, + 'AutoContrast': 0.25, + 'SolarizeIncreasing': .05, + 'SolarizeAdd': .05, + 'ContrastIncreasing': .05, + 'BrightnessIncreasing': .05, + 'Equalize': .05, + 'PosterizeIncreasing': 0.05, + 'Invert': 0.05, } -def _select_rand_weights(weight_idx=0, transforms=None): - transforms = transforms or _RAND_TRANSFORMS - assert weight_idx == 0 # only one set of weights currently - rand_weights = _RAND_CHOICE_WEIGHTS_0 - probs = [rand_weights[k] for k in transforms] - probs /= np.sum(probs) - return probs +def _get_weighted_transforms(transforms: Dict): + transforms, probs = list(zip(*transforms.items())) + probs = np.array(probs) + probs = probs / np.sum(probs) + return transforms, probs + +def rand_augment_choices(name: str, increasing=True): + if name == 'weights': + return _RAND_CHOICE_WEIGHTS_0 + elif name == '3aw': + return _RAND_CHOICE_3A + elif name == '3a': + return _RAND_3A + else: + return _RAND_INCREASING_TRANSFORMS if increasing else _RAND_TRANSFORMS -def rand_augment_ops(magnitude=10, hparams=None, transforms=None): + +def rand_augment_ops( + magnitude: Union[int, float] = 10, + prob: float = 0.5, + hparams: Optional[Dict] = None, + transforms: Optional[Union[Dict, List]] = None, +): hparams = hparams or _HPARAMS_DEFAULT transforms = transforms or _RAND_TRANSFORMS return [AugmentOp( - name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms] + name, prob=prob, magnitude=magnitude, hparams=hparams) for name in transforms] class RandAugment: @@ -648,11 +741,16 @@ class RandAugment: self.ops = ops self.num_layers = num_layers self.choice_weights = choice_weights + print(self.ops, self.choice_weights) def __call__(self, img): # no replacement when using weighted choice ops = np.random.choice( - self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights) + self.ops, + self.num_layers, + replace=self.choice_weights is None, + p=self.choice_weights, + ) for op in ops: img = op(img) return img @@ -665,61 +763,84 @@ class RandAugment: return fs -def rand_augment_transform(config_str, hparams): +def rand_augment_transform( + config_str: str, + hparams: Optional[Dict] = None, + transforms: Optional[Union[str, Dict, List]] = None, +): """ Create a RandAugment transform - :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by - dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). 
The remaining - sections, not order sepecific determine - 'm' - integer magnitude of rand augment - 'n' - integer num layers (number of transform ops selected per image) - 'w' - integer probabiliy weight index (index of a set of weights to influence choice of op) - 'mstd' - float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100) - 'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10) - 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0) - Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 - 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2 - - :param hparams: Other hparams (kwargs) for the RandAugmentation scheme - - :return: A PyTorch compatible Transform + Args: + config_str (str): String defining configuration of random augmentation. Consists of multiple sections separated + by dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). + The remaining sections, not order sepecific determine + 'm' - integer magnitude of rand augment + 'n' - integer num layers (number of transform ops selected per image) + 'p' - float probability of applying each layer (default 0.5) + 'mstd' - float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100) + 'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10) + 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0) + 't' - str name of transform set to use + Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 + 'rand-mstd1-tweights' results in mag std 1.0, weighted transforms, default mag of 10 and num_layers 2 + + hparams (dict): Other hparams (kwargs) for the RandAugmentation scheme + + Returns: + A PyTorch compatible Transform """ magnitude = _LEVEL_DENOM # default to _LEVEL_DENOM for magnitude (currently 10) num_layers = 2 # default to 2 ops per image - weight_idx = None # default to no probability weights for op choice - transforms = _RAND_TRANSFORMS + increasing = False + prob = 0.5 config = config_str.split('-') assert config[0] == 'rand' config = config[1:] for c in config: - cs = re.split(r'(\d.*)', c) - if len(cs) < 2: - continue - key, val = cs[:2] - if key == 'mstd': - # noise param / randomization of magnitude values - mstd = float(val) - if mstd > 100: - # use uniform sampling in 0 to magnitude if mstd is > 100 - mstd = float('inf') - hparams.setdefault('magnitude_std', mstd) - elif key == 'mmax': - # clip magnitude between [0, mmax] instead of default [0, _LEVEL_DENOM] - hparams.setdefault('magnitude_max', int(val)) - elif key == 'inc': - if bool(val): - transforms = _RAND_INCREASING_TRANSFORMS - elif key == 'm': - magnitude = int(val) - elif key == 'n': - num_layers = int(val) - elif key == 'w': - weight_idx = int(val) + if c.startswith('t'): + # NOTE old 'w' key was removed, 'w0' is not equivalent to 'tweights' + val = str(c[1:]) + if transforms is None: + transforms = val else: - assert False, 'Unknown RandAugment config section' - ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms) - choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx) + # numeric options + cs = re.split(r'(\d.*)', c) + if len(cs) < 2: + continue + key, val = cs[:2] + if key == 'mstd': + # noise param / 
randomization of magnitude values + mstd = float(val) + if mstd > 100: + # use uniform sampling in 0 to magnitude if mstd is > 100 + mstd = float('inf') + hparams.setdefault('magnitude_std', mstd) + elif key == 'mmax': + # clip magnitude between [0, mmax] instead of default [0, _LEVEL_DENOM] + hparams.setdefault('magnitude_max', int(val)) + elif key == 'inc': + if bool(val): + increasing = True + elif key == 'm': + magnitude = int(val) + elif key == 'n': + num_layers = int(val) + elif key == 'p': + prob = float(val) + else: + assert False, 'Unknown RandAugment config section' + + if isinstance(transforms, str): + transforms = rand_augment_choices(transforms, increasing=increasing) + elif transforms is None: + transforms = _RAND_INCREASING_TRANSFORMS if increasing else _RAND_TRANSFORMS + + choice_weights = None + if isinstance(transforms, Dict): + transforms, choice_weights = _get_weighted_transforms(transforms) + + ra_ops = rand_augment_ops(magnitude=magnitude, prob=prob, hparams=hparams, transforms=transforms) return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) @@ -740,11 +861,19 @@ _AUGMIX_TRANSFORMS = [ ] -def augmix_ops(magnitude=10, hparams=None, transforms=None): +def augmix_ops( + magnitude: Union[int, float] = 10, + hparams: Optional[Dict] = None, + transforms: Optional[Union[str, Dict, List]] = None, +): hparams = hparams or _HPARAMS_DEFAULT transforms = transforms or _AUGMIX_TRANSFORMS return [AugmentOp( - name, prob=1.0, magnitude=magnitude, hparams=hparams) for name in transforms] + name, + prob=1.0, + magnitude=magnitude, + hparams=hparams + ) for name in transforms] class AugMixAugment: @@ -820,22 +949,24 @@ class AugMixAugment: return fs -def augment_and_mix_transform(config_str, hparams): +def augment_and_mix_transform(config_str: str, hparams: Optional[Dict] = None): """ Create AugMix PyTorch transform - :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by - dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining - sections, not order sepecific determine - 'm' - integer magnitude (severity) of augmentation mix (default: 3) - 'w' - integer width of augmentation chain (default: 3) - 'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1) - 'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0) - 'mstd' - float std deviation of magnitude noise applied (default: 0) - Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2 - - :param hparams: Other hparams (kwargs) for the Augmentation transforms - - :return: A PyTorch compatible Transform + Args: + config_str (str): String defining configuration of random augmentation. Consists of multiple sections separated + by dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). 
+ The remaining sections, not order specific, determine + 'm' - integer magnitude (severity) of augmentation mix (default: 3) + 'w' - integer width of augmentation chain (default: 3) + 'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1) + 'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0) + 'mstd' - float std deviation of magnitude noise applied (default: 0) + Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2 + + hparams: Other hparams (kwargs) for the Augmentation transforms + + Returns: + A PyTorch compatible Transform """ magnitude = 3 width = 3 diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py index 6c28383a..7749b206 100644 --- a/timm/data/transforms_factory.py +++ b/timm/data/transforms_factory.py @@ -59,6 +59,7 @@ def transforms_imagenet_train( re_count=1, re_num_splits=0, separate=False, + force_color_jitter=False, ): """ If separate==True, the transforms are returned as a tuple of 3 separate transforms @@ -77,8 +78,12 @@ def transforms_imagenet_train( primary_tfl += [transforms.RandomVerticalFlip(p=vflip)] secondary_tfl = [] + disable_color_jitter = False if auto_augment: assert isinstance(auto_augment, str) + # color jitter is typically disabled if AA/RA on, + # this allows override without breaking old hparam cfgs + disable_color_jitter = not (force_color_jitter or '3a' in auto_augment) if isinstance(img_size, (tuple, list)): img_size_min = min(img_size) else: @@ -96,8 +101,9 @@ def transforms_imagenet_train( secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)] else: secondary_tfl += [auto_augment_transform(auto_augment, aa_params)] - elif color_jitter is not None: - # color jitter is enabled when not using AA + + if color_jitter is not None and not disable_color_jitter: + # color jitter is enabled when not using AA or when forced if isinstance(color_jitter, (list, tuple)): # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation # or 4 if also augmenting hue From e7da205345dcf770ee4bedd62d06fad7a1458904 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 10 Dec 2022 16:43:28 -0800 Subject: [PATCH 9/9] Fix aa min_max level clamp --- timm/data/auto_augment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py index e461f67c..a7701b82 100644 --- a/timm/data/auto_augment.py +++ b/timm/data/auto_augment.py @@ -228,7 +228,7 @@ def _minmax_level_to_arg(level, _hparams, min_val=0., max_val=1.0, clamp=True): level = (level / _LEVEL_DENOM) min_val + (max_val - min_val) * level if clamp: - level = min(min_val, max(max_val, level)) + level = max(min_val, min(max_val, level)) return level,
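Usage sketch for the augmentation changes above (not part of the patch; the config strings, hparams, and dummy image below are illustrative only):

    # Exercise the new 3-Augment policy and the new RandAugment config keys
    # ('p' = per-op probability, 't' = named transform set).
    from PIL import Image
    from timm.data.auto_augment import auto_augment_transform, rand_augment_transform

    hparams = dict(translate_const=250, img_mean=(128, 128, 128))  # example fill / translate hparams

    # DeiT III style 3-Augment: each image gets exactly one of solarize / grayscale / gaussian blur
    three_aug = auto_augment_transform('3a', hparams)

    # RandAugment: magnitude 9, magnitude_std 0.5, per-op prob 0.6, weighted '3aw' transform set
    rand_aug = rand_augment_transform('rand-m9-mstd0.5-p0.6-t3aw', hparams)

    img = Image.new('RGB', (224, 224))
    out_a = three_aug(img)  # PIL Image in, PIL Image out
    out_b = rand_aug(img)

In `transforms_imagenet_train`, an auto-augment string containing '3a' also disables the default color jitter unless `force_color_jitter=True` is passed.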