Major module / path restructure, timm.models.layers -> timm.layers, add _ prefix to all non model modules in timm.models

pull/1581/head
Ross Wightman 1 year ago
parent da6644b6ba
commit 927f031293

@ -16,7 +16,7 @@ import argparse
import os
import glob
import hashlib
from timm.models.helpers import load_state_dict
from timm.models import load_state_dict
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
parser.add_argument('--input', default='', type=str, metavar='PATH',

@ -13,7 +13,7 @@ import os
import hashlib
import shutil
from collections import OrderedDict
from timm.models.helpers import load_state_dict
from timm.models import load_state_dict
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',

@ -1,4 +1,3 @@
dependencies = ['torch']
from timm.models import registry
globals().update(registry._model_entrypoints)
import timm
globals().update(timm.models._registry._model_entrypoints)

@ -5,11 +5,11 @@ An example inference script that outputs top-k class ids for images in a folder
Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
import os
import time
import argparse
import json
import logging
import os
import time
from contextlib import suppress
from functools import partial
@ -17,12 +17,11 @@ import numpy as np
import pandas as pd
import torch
from timm.models import create_model, apply_test_time_pool, load_checkpoint
from timm.data import create_dataset, create_loader, resolve_data_config
from timm.layers import apply_test_time_pool
from timm.models import create_model
from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser
try:
from apex import amp
has_apex = True

@ -1,10 +1,7 @@
import pytest
import torch
import torch.nn as nn
import platform
import os
from timm.models.layers import create_act_layer, get_act_layer, set_layer_config
from timm.layers import create_act_layer, set_layer_config
class MLP(nn.Module):

@ -14,7 +14,7 @@ except ImportError:
import timm
from timm import list_models, create_model, set_scriptable, get_pretrained_cfg_value
from timm.models.fx_features import _leaf_modules, _autowrap_functions
from timm.models._features_fx import _leaf_modules, _autowrap_functions
if hasattr(torch._C, '_jit_set_profiling_executor'):
# legacy executor is too slow to compile large models for unit tests

@ -1,4 +1,4 @@
from .version import __version__
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable
from .models import create_model, list_models, list_pretrained, is_model, list_modules, model_entrypoint, \
is_scriptable, is_exportable, set_scriptable, set_exportable, \
is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value

@ -1,6 +1,7 @@
import os
import pickle
def load_class_map(map_or_filename, root=''):
if isinstance(map_or_filename, dict):
assert dict, 'class_map dict must be non-empty'
@ -14,7 +15,7 @@ def load_class_map(map_or_filename, root=''):
with open(class_map_path) as f:
class_to_idx = {v.strip(): k for k, v in enumerate(f)}
elif class_map_ext == '.pkl':
with open(class_map_path,'rb') as f:
with open(class_map_path, 'rb') as f:
class_to_idx = pickle.load(f)
else:
assert False, f'Unsupported class map file extension ({class_map_ext}).'

@ -0,0 +1,44 @@
from .activations import *
from .adaptive_avgmax_pool import \
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
from .blur_pool import BlurPool2d
from .classifier import ClassifierHead, create_classifier
from .cond_conv2d import CondConv2d, get_condconv_initializer
from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
set_layer_config
from .conv2d_same import Conv2dSame, conv2d_same
from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
from .create_act import create_act_layer, get_act_layer, get_act_fn
from .create_attn import get_attn, create_attn
from .create_conv2d import create_conv2d
from .create_norm import get_norm_layer, create_norm_layer
from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\
EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm
from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
from .gather_excite import GatherExcite
from .global_context import GlobalContext
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
from .inplace_abn import InplaceAbn
from .linear import Linear
from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
from .padding import get_padding, get_same_padding, pad_same
from .patch_embed import PatchEmbed
from .pool2d_same import AvgPool2dSame, create_pool2d
from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
from .selective_kernel import SelectiveKernel
from .separable_conv import SeparableConv2d, SeparableConvNormAct
from .space_to_depth import SpaceToDepthModule
from .split_attn import SplitAttn
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .trace_utils import _assert, _float_to_int
from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_

@ -64,12 +64,18 @@ from .xception import *
from .xception_aligned import *
from .xcit import *
from .factory import create_model, parse_model_name, safe_model_name
from .helpers import load_checkpoint, resume_checkpoint, model_parameters
from .layers import TestTimePoolHead, apply_test_time_pool
from .layers import convert_splitbn_model, convert_sync_batchnorm
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
from .layers import set_fast_norm
from .pretrained import PretrainedCfg, filter_pretrained_cfg, generate_default_cfgs, split_model_name_tag
from .registry import register_model, model_entrypoint, list_models, list_pretrained, is_model, list_modules,\
from ._builder import build_model_with_cfg, load_pretrained, load_custom_pretrained, resolve_pretrained_cfg, \
set_pretrained_download_progress, set_pretrained_check_hash
from ._factory import create_model, parse_model_name, safe_model_name
from ._features import FeatureInfo, FeatureHooks, FeatureHookNet, FeatureListNet, FeatureDictNet
from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, \
register_notrace_module, register_notrace_function
from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_checkpoint, resume_checkpoint
from ._hub import load_model_config_from_hf, load_state_dict_from_hf, push_to_hf_hub
from ._manipulate import model_parameters, named_apply, named_modules, named_modules_with_params, \
group_modules, group_parameters, checkpoint_seq, adapt_input_conv
from ._pretrained import PretrainedCfg, DefaultCfg, \
filter_pretrained_cfg, generate_default_cfgs, split_model_name_tag
from ._prune import adapt_model_from_string
from ._registry import register_model, model_entrypoint, list_models, list_pretrained, is_model, list_modules, \
is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value

@ -0,0 +1,395 @@
import dataclasses
import logging
from copy import deepcopy
from typing import Optional, Dict, Callable, Any, Tuple
from torch import nn as nn
from torch.hub import load_state_dict_from_url
from timm.models._features import FeatureListNet, FeatureHookNet
from timm.models._features_fx import FeatureGraphNet
from timm.models._helpers import load_state_dict
from timm.models._hub import has_hf_hub, download_cached_file, load_state_dict_from_hf
from timm.models._manipulate import adapt_input_conv
from timm.models._pretrained import PretrainedCfg
from timm.models._prune import adapt_model_from_file
from timm.models._registry import get_pretrained_cfg
_logger = logging.getLogger(__name__)
# Global variables for rarely used pretrained checkpoint download progress and hash check.
# Use set_pretrained_download_progress / set_pretrained_check_hash functions to toggle.
_DOWNLOAD_PROGRESS = False
_CHECK_HASH = False
def _resolve_pretrained_source(pretrained_cfg):
cfg_source = pretrained_cfg.get('source', '')
pretrained_url = pretrained_cfg.get('url', None)
pretrained_file = pretrained_cfg.get('file', None)
hf_hub_id = pretrained_cfg.get('hf_hub_id', None)
# resolve where to load pretrained weights from
load_from = ''
pretrained_loc = ''
if cfg_source == 'hf-hub' and has_hf_hub(necessary=True):
# hf-hub specified as source via model identifier
load_from = 'hf-hub'
assert hf_hub_id
pretrained_loc = hf_hub_id
else:
# default source == timm or unspecified
if pretrained_file:
load_from = 'file'
pretrained_loc = pretrained_file
elif pretrained_url:
load_from = 'url'
pretrained_loc = pretrained_url
elif hf_hub_id and has_hf_hub(necessary=True):
# hf-hub available as alternate weight source in default_cfg
load_from = 'hf-hub'
pretrained_loc = hf_hub_id
if load_from == 'hf-hub' and pretrained_cfg.get('hf_hub_filename', None):
# if a filename override is set, return tuple for location w/ (hub_id, filename)
pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename']
return load_from, pretrained_loc
def set_pretrained_download_progress(enable=True):
""" Set download progress for pretrained weights on/off (globally). """
global _DOWNLOAD_PROGRESS
_DOWNLOAD_PROGRESS = enable
def set_pretrained_check_hash(enable=True):
""" Set hash checking for pretrained weights on/off (globally). """
global _CHECK_HASH
_CHECK_HASH = enable
def load_custom_pretrained(
model: nn.Module,
pretrained_cfg: Optional[Dict] = None,
load_fn: Optional[Callable] = None,
):
r"""Loads a custom (read non .pth) weight file
Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls
a passed in custom load fun, or the `load_pretrained` model member fn.
If the object is already present in `model_dir`, it's deserialized and returned.
The default value of `model_dir` is ``<hub_dir>/checkpoints`` where
`hub_dir` is the directory returned by :func:`~torch.hub.get_dir`.
Args:
model: The instantiated model to load weights into
pretrained_cfg (dict): Default pretrained model cfg
load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named
'laod_pretrained' on the model will be called if it exists
"""
pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
if not pretrained_cfg:
_logger.warning("Invalid pretrained config, cannot load weights.")
return
load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg)
if not load_from:
_logger.warning("No pretrained weights exist for this model. Using random initialization.")
return
if load_from == 'hf-hub': # FIXME
_logger.warning("Hugging Face hub not currently supported for custom load pretrained models.")
elif load_from == 'url':
pretrained_loc = download_cached_file(
pretrained_loc,
check_hash=_CHECK_HASH,
progress=_DOWNLOAD_PROGRESS
)
if load_fn is not None:
load_fn(model, pretrained_loc)
elif hasattr(model, 'load_pretrained'):
model.load_pretrained(pretrained_loc)
else:
_logger.warning("Valid function to load pretrained weights is not available, using random initialization.")
def load_pretrained(
model: nn.Module,
pretrained_cfg: Optional[Dict] = None,
num_classes: int = 1000,
in_chans: int = 3,
filter_fn: Optional[Callable] = None,
strict: bool = True,
):
""" Load pretrained checkpoint
Args:
model (nn.Module) : PyTorch model module
pretrained_cfg (Optional[Dict]): configuration for pretrained weights / target dataset
num_classes (int): num_classes for target model
in_chans (int): in_chans for target model
filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args)
strict (bool): strict load of checkpoint
"""
pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
if not pretrained_cfg:
_logger.warning("Invalid pretrained config, cannot load weights.")
return
load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg)
if load_from == 'file':
_logger.info(f'Loading pretrained weights from file ({pretrained_loc})')
state_dict = load_state_dict(pretrained_loc)
elif load_from == 'url':
_logger.info(f'Loading pretrained weights from url ({pretrained_loc})')
state_dict = load_state_dict_from_url(
pretrained_loc,
map_location='cpu',
progress=_DOWNLOAD_PROGRESS,
check_hash=_CHECK_HASH,
)
elif load_from == 'hf-hub':
_logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
if isinstance(pretrained_loc, (list, tuple)):
state_dict = load_state_dict_from_hf(*pretrained_loc)
else:
state_dict = load_state_dict_from_hf(pretrained_loc)
else:
_logger.warning("No pretrained weights exist or were found for this model. Using random initialization.")
return
if filter_fn is not None:
# for backwards compat with filter fn that take one arg, try one first, the two
try:
state_dict = filter_fn(state_dict)
except TypeError:
state_dict = filter_fn(state_dict, model)
input_convs = pretrained_cfg.get('first_conv', None)
if input_convs is not None and in_chans != 3:
if isinstance(input_convs, str):
input_convs = (input_convs,)
for input_conv_name in input_convs:
weight_name = input_conv_name + '.weight'
try:
state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name])
_logger.info(
f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)')
except NotImplementedError as e:
del state_dict[weight_name]
strict = False
_logger.warning(
f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.')
classifiers = pretrained_cfg.get('classifier', None)
label_offset = pretrained_cfg.get('label_offset', 0)
if classifiers is not None:
if isinstance(classifiers, str):
classifiers = (classifiers,)
if num_classes != pretrained_cfg['num_classes']:
for classifier_name in classifiers:
# completely discard fully connected if model num_classes doesn't match pretrained weights
state_dict.pop(classifier_name + '.weight', None)
state_dict.pop(classifier_name + '.bias', None)
strict = False
elif label_offset > 0:
for classifier_name in classifiers:
# special case for pretrained weights with an extra background class in pretrained weights
classifier_weight = state_dict[classifier_name + '.weight']
state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
classifier_bias = state_dict[classifier_name + '.bias']
state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
model.load_state_dict(state_dict, strict=strict)
def pretrained_cfg_for_features(pretrained_cfg):
pretrained_cfg = deepcopy(pretrained_cfg)
# remove default pretrained cfg fields that don't have much relevance for feature backbone
to_remove = ('num_classes', 'classifier', 'global_pool') # add default final pool size?
for tr in to_remove:
pretrained_cfg.pop(tr, None)
return pretrained_cfg
def _filter_kwargs(kwargs, names):
if not kwargs or not names:
return
for n in names:
kwargs.pop(n, None)
def _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter):
""" Update the default_cfg and kwargs before passing to model
Args:
pretrained_cfg: input pretrained cfg (updated in-place)
kwargs: keyword args passed to model build fn (updated in-place)
kwargs_filter: keyword arg keys that must be removed before model __init__
"""
# Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
default_kwarg_names = ('num_classes', 'global_pool', 'in_chans')
if pretrained_cfg.get('fixed_input_size', False):
# if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size
default_kwarg_names += ('img_size',)
for n in default_kwarg_names:
# for legacy reasons, model __init__args uses img_size + in_chans as separate args while
# pretrained_cfg has one input_size=(C, H ,W) entry
if n == 'img_size':
input_size = pretrained_cfg.get('input_size', None)
if input_size is not None:
assert len(input_size) == 3
kwargs.setdefault(n, input_size[-2:])
elif n == 'in_chans':
input_size = pretrained_cfg.get('input_size', None)
if input_size is not None:
assert len(input_size) == 3
kwargs.setdefault(n, input_size[0])
else:
default_val = pretrained_cfg.get(n, None)
if default_val is not None:
kwargs.setdefault(n, pretrained_cfg[n])
# Filter keyword args for task specific model variants (some 'features only' models, etc.)
_filter_kwargs(kwargs, names=kwargs_filter)
def resolve_pretrained_cfg(
variant: str,
pretrained_cfg=None,
pretrained_cfg_overlay=None,
) -> PretrainedCfg:
model_with_tag = variant
pretrained_tag = None
if pretrained_cfg:
if isinstance(pretrained_cfg, dict):
# pretrained_cfg dict passed as arg, validate by converting to PretrainedCfg
pretrained_cfg = PretrainedCfg(**pretrained_cfg)
elif isinstance(pretrained_cfg, str):
pretrained_tag = pretrained_cfg
pretrained_cfg = None
# fallback to looking up pretrained cfg in model registry by variant identifier
if not pretrained_cfg:
if pretrained_tag:
model_with_tag = '.'.join([variant, pretrained_tag])
pretrained_cfg = get_pretrained_cfg(model_with_tag)
if not pretrained_cfg:
_logger.warning(
f"No pretrained configuration specified for {model_with_tag} model. Using a default."
f" Please add a config to the model pretrained_cfg registry or pass explicitly.")
pretrained_cfg = PretrainedCfg() # instance with defaults
pretrained_cfg_overlay = pretrained_cfg_overlay or {}
if not pretrained_cfg.architecture:
pretrained_cfg_overlay.setdefault('architecture', variant)
pretrained_cfg = dataclasses.replace(pretrained_cfg, **pretrained_cfg_overlay)
return pretrained_cfg
def build_model_with_cfg(
model_cls: Callable,
variant: str,
pretrained: bool,
pretrained_cfg: Optional[Dict] = None,
pretrained_cfg_overlay: Optional[Dict] = None,
model_cfg: Optional[Any] = None,
feature_cfg: Optional[Dict] = None,
pretrained_strict: bool = True,
pretrained_filter_fn: Optional[Callable] = None,
kwargs_filter: Optional[Tuple[str]] = None,
**kwargs,
):
""" Build model with specified default_cfg and optional model_cfg
This helper fn aids in the construction of a model including:
* handling default_cfg and associated pretrained weight loading
* passing through optional model_cfg for models with config based arch spec
* features_only model adaptation
* pruning config / model adaptation
Args:
model_cls (nn.Module): model class
variant (str): model variant name
pretrained (bool): load pretrained weights
pretrained_cfg (dict): model's pretrained weight/task config
model_cfg (Optional[Dict]): model's architecture config
feature_cfg (Optional[Dict]: feature extraction adapter config
pretrained_strict (bool): load pretrained weights strictly
pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights
kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model
**kwargs: model args passed through to model __init__
"""
pruned = kwargs.pop('pruned', False)
features = False
feature_cfg = feature_cfg or {}
# resolve and update model pretrained config and model kwargs
pretrained_cfg = resolve_pretrained_cfg(
variant,
pretrained_cfg=pretrained_cfg,
pretrained_cfg_overlay=pretrained_cfg_overlay
)
# FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model
pretrained_cfg = pretrained_cfg.to_dict()
_update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter)
# Setup for feature extraction wrapper done at end of this fn
if kwargs.pop('features_only', False):
features = True
feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
if 'out_indices' in kwargs:
feature_cfg['out_indices'] = kwargs.pop('out_indices')
# Instantiate the model
if model_cfg is None:
model = model_cls(**kwargs)
else:
model = model_cls(cfg=model_cfg, **kwargs)
model.pretrained_cfg = pretrained_cfg
model.default_cfg = model.pretrained_cfg # alias for backwards compat
if pruned:
model = adapt_model_from_file(model, variant)
# For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
if pretrained:
if pretrained_cfg.get('custom_load', False):
load_custom_pretrained(
model,
pretrained_cfg=pretrained_cfg,
)
else:
load_pretrained(
model,
pretrained_cfg=pretrained_cfg,
num_classes=num_classes_pretrained,
in_chans=kwargs.get('in_chans', 3),
filter_fn=pretrained_filter_fn,
strict=pretrained_strict,
)
# Wrap the model in a feature extraction module if enabled
if features:
feature_cls = FeatureListNet
if 'feature_cls' in feature_cfg:
feature_cls = feature_cfg.pop('feature_cls')
if isinstance(feature_cls, str):
feature_cls = feature_cls.lower()
if 'hook' in feature_cls:
feature_cls = FeatureHookNet
elif feature_cls == 'fx':
feature_cls = FeatureGraphNet
else:
assert False, f'Unknown feature class {feature_cls}'
model = feature_cls(model, **feature_cfg)
model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg) # add back default_cfg
model.default_cfg = model.pretrained_cfg # alias for backwards compat
return model

@ -2,13 +2,12 @@
Hacked together by / Copyright 2019, Ross Wightman
"""
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from .layers import create_conv2d, DropPath, make_divisible, create_act_layer, get_norm_act_layer
from timm.layers import create_conv2d, DropPath, make_divisible, create_act_layer, get_norm_act_layer
__all__ = [
'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual']

@ -14,8 +14,8 @@ from functools import partial
import torch.nn as nn
from .efficientnet_blocks import *
from .layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible
from ._efficientnet_blocks import *
from timm.layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible
__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights",
'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT']

@ -2,11 +2,11 @@ import os
from typing import Any, Dict, Optional, Union
from urllib.parse import urlsplit
from .pretrained import PretrainedCfg, split_model_name_tag
from .helpers import load_checkpoint
from .hub import load_model_config_from_hf
from .layers import set_layer_config
from .registry import is_model, model_entrypoint
from timm.layers import set_layer_config
from ._pretrained import PretrainedCfg, split_model_name_tag
from ._helpers import load_checkpoint
from ._hub import load_model_config_from_hf
from ._registry import is_model, model_entrypoint
def parse_model_name(model_name):

@ -6,7 +6,7 @@ from typing import Callable, List, Dict, Union, Type
import torch
from torch import nn
from .features import _get_feature_info
from ._features import _get_feature_info
try:
from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor
@ -15,9 +15,9 @@ except ImportError:
has_fx_feature_extraction = False
# Layers we went to treat as leaf modules
from .layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame
from .layers.non_local_attn import BilinearAttnTransform
from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame
from timm.layers.non_local_attn import BilinearAttnTransform
from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
# BUT modules from timm.models should use the registration mechanism below
@ -29,7 +29,7 @@ _leaf_modules = {
}
try:
from .layers import InplaceAbn
from timm.layers import InplaceAbn
_leaf_modules.add(InplaceAbn)
except ImportError:
pass

@ -0,0 +1,113 @@
""" Model creation / weight loading / state_dict helpers
Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
import os
from collections import OrderedDict
import torch
import timm.models._builder
_logger = logging.getLogger(__name__)
def clean_state_dict(state_dict):
# 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training
cleaned_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] if k.startswith('module.') else k
cleaned_state_dict[name] = v
return cleaned_state_dict
def load_state_dict(checkpoint_path, use_ema=True):
if checkpoint_path and os.path.isfile(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location='cpu')
state_dict_key = ''
if isinstance(checkpoint, dict):
if use_ema and checkpoint.get('state_dict_ema', None) is not None:
state_dict_key = 'state_dict_ema'
elif use_ema and checkpoint.get('model_ema', None) is not None:
state_dict_key = 'model_ema'
elif 'state_dict' in checkpoint:
state_dict_key = 'state_dict'
elif 'model' in checkpoint:
state_dict_key = 'model'
state_dict = clean_state_dict(checkpoint[state_dict_key] if state_dict_key else checkpoint)
_logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
return state_dict
else:
_logger.error("No checkpoint found at '{}'".format(checkpoint_path))
raise FileNotFoundError()
def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True, remap=False):
if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
# numpy checkpoint, try to load via model specific load_pretrained fn
if hasattr(model, 'load_pretrained'):
timm.models._model_builder.load_pretrained(checkpoint_path)
else:
raise NotImplementedError('Model cannot load numpy checkpoint')
return
state_dict = load_state_dict(checkpoint_path, use_ema)
if remap:
state_dict = remap_checkpoint(model, state_dict)
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
return incompatible_keys
def remap_checkpoint(model, state_dict, allow_reshape=True):
""" remap checkpoint by iterating over state dicts in order (ignoring original keys).
This assumes models (and originating state dict) were created with params registered in same order.
"""
out_dict = {}
for (ka, va), (kb, vb) in zip(model.state_dict().items(), state_dict.items()):
assert va.numel == vb.numel, f'Tensor size mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
if va.shape != vb.shape:
if allow_reshape:
vb = vb.reshape(va.shape)
else:
assert False, f'Tensor shape mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
out_dict[ka] = vb
return out_dict
def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
resume_epoch = None
if os.path.isfile(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location='cpu')
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
if log_info:
_logger.info('Restoring model state from checkpoint...')
state_dict = clean_state_dict(checkpoint['state_dict'])
model.load_state_dict(state_dict)
if optimizer is not None and 'optimizer' in checkpoint:
if log_info:
_logger.info('Restoring optimizer state from checkpoint...')
optimizer.load_state_dict(checkpoint['optimizer'])
if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
if log_info:
_logger.info('Restoring AMP loss scaler state from checkpoint...')
loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])
if 'epoch' in checkpoint:
resume_epoch = checkpoint['epoch']
if 'version' in checkpoint and checkpoint['version'] > 1:
resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save
if log_info:
_logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
else:
model.load_state_dict(checkpoint)
if log_info:
_logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
return resume_epoch
else:
_logger.error("No checkpoint found at '{}'".format(checkpoint_path))
raise FileNotFoundError()

@ -15,7 +15,7 @@ except ImportError:
from torch.hub import _get_torch_home as get_dir
from timm import __version__
from timm.models.pretrained import filter_pretrained_cfg
from timm.models._pretrained import filter_pretrained_cfg
try:
from huggingface_hub import (

@ -0,0 +1,255 @@
import collections.abc
import math
import re
from collections import defaultdict
from itertools import chain
from typing import Callable, Union, Dict
import torch
from torch import nn as nn
from torch.utils.checkpoint import checkpoint
def model_parameters(model, exclude_head=False):
if exclude_head:
# FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering
return [p for p in model.parameters()][:-2]
else:
return model.parameters()
def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module:
if not depth_first and include_root:
fn(module=module, name=name)
for child_name, child_module in module.named_children():
child_name = '.'.join((name, child_name)) if name else child_name
named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
if depth_first and include_root:
fn(module=module, name=name)
return module
def named_modules(module: nn.Module, name='', depth_first=True, include_root=False):
if not depth_first and include_root:
yield name, module
for child_name, child_module in module.named_children():
child_name = '.'.join((name, child_name)) if name else child_name
yield from named_modules(
module=child_module, name=child_name, depth_first=depth_first, include_root=True)
if depth_first and include_root:
yield name, module
def named_modules_with_params(module: nn.Module, name='', depth_first=True, include_root=False):
if module._parameters and not depth_first and include_root:
yield name, module
for child_name, child_module in module.named_children():
child_name = '.'.join((name, child_name)) if name else child_name
yield from named_modules_with_params(
module=child_module, name=child_name, depth_first=depth_first, include_root=True)
if module._parameters and depth_first and include_root:
yield name, module
MATCH_PREV_GROUP = (99999,)
def group_with_matcher(
named_objects,
group_matcher: Union[Dict, Callable],
output_values: bool = False,
reverse: bool = False
):
if isinstance(group_matcher, dict):
# dictionary matcher contains a dict of raw-string regex expr that must be compiled
compiled = []
for group_ordinal, (group_name, mspec) in enumerate(group_matcher.items()):
if mspec is None:
continue
# map all matching specifications into 3-tuple (compiled re, prefix, suffix)
if isinstance(mspec, (tuple, list)):
# multi-entry match specifications require each sub-spec to be a 2-tuple (re, suffix)
for sspec in mspec:
compiled += [(re.compile(sspec[0]), (group_ordinal,), sspec[1])]
else:
compiled += [(re.compile(mspec), (group_ordinal,), None)]
group_matcher = compiled
def _get_grouping(name):
if isinstance(group_matcher, (list, tuple)):
for match_fn, prefix, suffix in group_matcher:
r = match_fn.match(name)
if r:
parts = (prefix, r.groups(), suffix)
# map all tuple elem to int for numeric sort, filter out None entries
return tuple(map(float, chain.from_iterable(filter(None, parts))))
return float('inf'), # un-matched layers (neck, head) mapped to largest ordinal
else:
ord = group_matcher(name)
if not isinstance(ord, collections.abc.Iterable):
return ord,
return tuple(ord)
# map layers into groups via ordinals (ints or tuples of ints) from matcher
grouping = defaultdict(list)
for k, v in named_objects:
grouping[_get_grouping(k)].append(v if output_values else k)
# remap to integers
layer_id_to_param = defaultdict(list)
lid = -1
for k in sorted(filter(lambda x: x is not None, grouping.keys())):
if lid < 0 or k[-1] != MATCH_PREV_GROUP[0]:
lid += 1
layer_id_to_param[lid].extend(grouping[k])
if reverse:
assert not output_values, "reverse mapping only sensible for name output"
# output reverse mapping
param_to_layer_id = {}
for lid, lm in layer_id_to_param.items():
for n in lm:
param_to_layer_id[n] = lid
return param_to_layer_id
return layer_id_to_param
def group_parameters(
module: nn.Module,
group_matcher,
output_values=False,
reverse=False,
):
return group_with_matcher(
module.named_parameters(), group_matcher, output_values=output_values, reverse=reverse)
def group_modules(
module: nn.Module,
group_matcher,
output_values=False,
reverse=False,
):
return group_with_matcher(
named_modules_with_params(module), group_matcher, output_values=output_values, reverse=reverse)
def flatten_modules(named_modules, depth=1, prefix='', module_types='sequential'):
prefix_is_tuple = isinstance(prefix, tuple)
if isinstance(module_types, str):
if module_types == 'container':
module_types = (nn.Sequential, nn.ModuleList, nn.ModuleDict)
else:
module_types = (nn.Sequential,)
for name, module in named_modules:
if depth and isinstance(module, module_types):
yield from flatten_modules(
module.named_children(),
depth - 1,
prefix=(name,) if prefix_is_tuple else name,
module_types=module_types,
)
else:
if prefix_is_tuple:
name = prefix + (name,)
yield name, module
else:
if prefix:
name = '.'.join([prefix, name])
yield name, module
def checkpoint_seq(
functions,
x,
every=1,
flatten=False,
skip_last=False,
preserve_rng_state=True
):
r"""A helper function for checkpointing sequential models.
Sequential models execute a list of modules/functions in order
(sequentially). Therefore, we can divide such a sequence into segments
and checkpoint each segment. All segments except run in :func:`torch.no_grad`
manner, i.e., not storing the intermediate activations. The inputs of each
checkpointed segment will be saved for re-running the segment in the backward pass.
See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.
.. warning::
Checkpointing currently only supports :func:`torch.autograd.backward`
and only if its `inputs` argument is not passed. :func:`torch.autograd.grad`
is not supported.
.. warning:
At least one of the inputs needs to have :code:`requires_grad=True` if
grads are needed for model inputs, otherwise the checkpointed part of the
model won't have gradients.
Args:
functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially.
x: A Tensor that is input to :attr:`functions`
every: checkpoint every-n functions (default: 1)
flatten (bool): flatten nn.Sequential of nn.Sequentials
skip_last (bool): skip checkpointing the last function in the sequence if True
preserve_rng_state (bool, optional, default=True): Omit stashing and restoring
the RNG state during each checkpoint.
Returns:
Output of running :attr:`functions` sequentially on :attr:`*inputs`
Example:
>>> model = nn.Sequential(...)
>>> input_var = checkpoint_seq(model, input_var, every=2)
"""
def run_function(start, end, functions):
def forward(_x):
for j in range(start, end + 1):
_x = functions[j](_x)
return _x
return forward
if isinstance(functions, torch.nn.Sequential):
functions = functions.children()
if flatten:
functions = chain.from_iterable(functions)
if not isinstance(functions, (tuple, list)):
functions = tuple(functions)
num_checkpointed = len(functions)
if skip_last:
num_checkpointed -= 1
end = -1
for start in range(0, num_checkpointed, every):
end = min(start + every - 1, num_checkpointed - 1)
x = checkpoint(run_function(start, end, functions), x, preserve_rng_state=preserve_rng_state)
if skip_last:
return run_function(end + 1, len(functions) - 1, functions)(x)
return x
def adapt_input_conv(in_chans, conv_weight):
conv_type = conv_weight.dtype
conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU
O, I, J, K = conv_weight.shape
if in_chans == 1:
if I > 3:
assert conv_weight.shape[1] % 3 == 0
# For models with space2depth stems
conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)
conv_weight = conv_weight.sum(dim=2, keepdim=False)
else:
conv_weight = conv_weight.sum(dim=1, keepdim=True)
elif in_chans != 3:
if I != 3:
raise NotImplementedError('Weight format not supported by conversion.')
else:
# NOTE this strategy should be better than random init, but there could be other combinations of
# the original RGB input layer weights that'd work better for specific cases.
repeat = int(math.ceil(in_chans / 3))
conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
conv_weight *= (3 / float(in_chans))
conv_weight = conv_weight.to(conv_type)
return conv_weight

@ -0,0 +1,111 @@
import os
from copy import deepcopy
from torch import nn as nn
from timm.layers import Conv2dSame, BatchNormAct2d, Linear
def extract_layer(model, layer):
layer = layer.split('.')
module = model
if hasattr(model, 'module') and layer[0] != 'module':
module = model.module
if not hasattr(model, 'module') and layer[0] == 'module':
layer = layer[1:]
for l in layer:
if hasattr(module, l):
if not l.isdigit():
module = getattr(module, l)
else:
module = module[int(l)]
else:
return module
return module
def set_layer(model, layer, val):
layer = layer.split('.')
module = model
if hasattr(model, 'module') and layer[0] != 'module':
module = model.module
lst_index = 0
module2 = module
for l in layer:
if hasattr(module2, l):
if not l.isdigit():
module2 = getattr(module2, l)
else:
module2 = module2[int(l)]
lst_index += 1
lst_index -= 1
for l in layer[:lst_index]:
if not l.isdigit():
module = getattr(module, l)
else:
module = module[int(l)]
l = layer[lst_index]
setattr(module, l, val)
def adapt_model_from_string(parent_module, model_string):
separator = '***'
state_dict = {}
lst_shape = model_string.split(separator)
for k in lst_shape:
k = k.split(':')
key = k[0]
shape = k[1][1:-1].split(',')
if shape[0] != '':
state_dict[key] = [int(i) for i in shape]
new_module = deepcopy(parent_module)
for n, m in parent_module.named_modules():
old_module = extract_layer(parent_module, n)
if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame):
if isinstance(old_module, Conv2dSame):
conv = Conv2dSame
else:
conv = nn.Conv2d
s = state_dict[n + '.weight']
in_channels = s[1]
out_channels = s[0]
g = 1
if old_module.groups > 1:
in_channels = out_channels
g = in_channels
new_conv = conv(
in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size,
bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
groups=g, stride=old_module.stride)
set_layer(new_module, n, new_conv)
elif isinstance(old_module, BatchNormAct2d):
new_bn = BatchNormAct2d(
state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
affine=old_module.affine, track_running_stats=True)
new_bn.drop = old_module.drop
new_bn.act = old_module.act
set_layer(new_module, n, new_bn)
elif isinstance(old_module, nn.BatchNorm2d):
new_bn = nn.BatchNorm2d(
num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
affine=old_module.affine, track_running_stats=True)
set_layer(new_module, n, new_bn)
elif isinstance(old_module, nn.Linear):
# FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
num_features = state_dict[n + '.weight'][1]
new_fc = Linear(
in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
set_layer(new_module, n, new_fc)
if hasattr(new_module, 'num_features'):
new_module.num_features = num_features
new_module.eval()
parent_module.eval()
return new_module
def adapt_model_from_file(parent_module, model_variant):
adapt_file = os.path.join(os.path.dirname(__file__), '_pruned', model_variant + '.txt')
with open(adapt_file, 'r') as f:
return adapt_model_from_string(parent_module, f.read().strip())

@ -9,7 +9,7 @@ from collections import defaultdict, deque
from copy import deepcopy
from typing import List, Optional, Union, Tuple
from .pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
__all__ = [
'list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
@ -167,10 +167,12 @@ def is_model(model_name):
return arch_name in _model_entrypoints
def model_entrypoint(model_name):
def model_entrypoint(model_name, module_filter: Optional[str] = None):
"""Fetch a model entrypoint for specified model name
"""
arch_name = get_arch_name(model_name)
if module_filter and arch_name not in _module_to_models.get(module_filter, {}):
raise RuntimeError(f'Model ({model_name} not found in module {module_filter}.')
return _model_entrypoints[arch_name]

@ -46,12 +46,13 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_
from .registry import register_model
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_
from ._builder import build_model_with_cfg
from ._registry import register_model
from .vision_transformer import checkpoint_filter_fn
__all__ = ['Beit']
def _cfg(url='', **kwargs):
return {

@ -13,9 +13,9 @@ Consider all of the models definitions here as experimental WIP and likely to ch
Hacked together by / copyright Ross Wightman, 2021.
"""
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from ._builder import build_model_with_cfg
from ._registry import register_model
from .byobnet import ByoBlockCfg, ByoModelCfg, ByobNet, interleave_blocks
from .helpers import build_model_with_cfg
from .registry import register_model
__all__ = []

@ -26,18 +26,18 @@ Hacked together by / copyright Ross Wightman, 2021.
"""
import math
from dataclasses import dataclass, field, replace
from typing import Tuple, List, Dict, Optional, Union, Any, Callable, Sequence
from functools import partial
from typing import Tuple, List, Dict, Optional, Union, Any, Callable, Sequence
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, named_apply, checkpoint_seq
from .layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0, EvoNorm2dS0a,\
EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a, FilterResponseNormAct2d, FilterResponseNormTlu2d
from .registry import register_model
from timm.layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0a
from ._builder import build_model_with_cfg
from ._manipulate import named_apply, checkpoint_seq
from ._registry import register_model
__all__ = ['ByobNet', 'ByoModelCfg', 'ByoBlockCfg', 'create_byob_stem', 'create_block']

@ -8,17 +8,16 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
"""
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
from copy import deepcopy
from functools import partial
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, checkpoint_seq
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_
from .registry import register_model
from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import register_model
__all__ = ['Cait', 'ClassAttn', 'LayerScaleBlockClassAttn', 'LayerScaleBlock', 'TalkingHeadAttn']

@ -7,7 +7,6 @@ Official CoaT code at: https://github.com/mlpc-ucsd/CoaT
Modified from timm/models/vision_transformer.py
"""
from copy import deepcopy
from functools import partial
from typing import Tuple, List, Union
@ -16,19 +15,11 @@ import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
from .registry import register_model
from .layers import _assert
__all__ = [
"coat_tiny",
"coat_mini",
"coat_lite_tiny",
"coat_lite_mini",
"coat_lite_small"
]
from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert
from ._builder import build_model_with_cfg
from ._registry import register_model
__all__ = ['CoaT']
def _cfg_coat(url='', **kwargs):

@ -22,20 +22,20 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
'''
from functools import partial
import torch
import torch.nn as nn
from functools import partial
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import DropPath, to_2tuple, trunc_normal_, PatchEmbed, Mlp
from .registry import register_model
from timm.layers import DropPath, trunc_normal_, PatchEmbed, Mlp
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_module
from ._registry import register_model
from .vision_transformer_hybrid import HybridEmbed
from .fx_features import register_notrace_module
import torch
import torch.nn as nn
__all__ = ['ConViT']
def _cfg(url='', **kwargs):

@ -5,9 +5,12 @@ import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.registry import register_model
from .helpers import build_model_with_cfg, checkpoint_seq
from .layers import SelectAdaptivePool2d
from timm.layers import SelectAdaptivePool2d
from ._registry import register_model
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
__all__ = ['ConvMixer']
def _cfg(url='', **kwargs):

@ -18,12 +18,12 @@ import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, LayerNorm, \
from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, LayerNorm, \
create_conv2d, get_act_layer, make_divisible, to_ntuple
from .pretrained import generate_default_cfgs
from .registry import register_model
from ._builder import build_model_with_cfg
from ._manipulate import named_apply, checkpoint_seq
from ._pretrained import generate_default_cfgs
from ._registry import register_model
__all__ = ['ConvNeXt'] # model_registry will add each entrypoint fn to this

@ -24,21 +24,22 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
Modifed from Timm. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
from functools import partial
from typing import List
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.hub
from functools import partial
from typing import List
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg
from .layers import DropPath, to_2tuple, trunc_normal_, _assert
from .registry import register_model
from .vision_transformer import Mlp, Block
from timm.layers import DropPath, to_2tuple, trunc_normal_, _assert
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._registry import register_model
from .vision_transformer import Block
__all__ = ['CrossViT'] # model_registry will add each entrypoint fn to this
def _cfg(url='', **kwargs):

@ -12,20 +12,18 @@ Reference impl via darknet cfg files at https://github.com/WongKinYiu/CrossStage
Hacked together by / Copyright 2020 Ross Wightman
"""
import collections.abc
from dataclasses import dataclass, field, asdict
from dataclasses import dataclass, asdict
from functools import partial
from typing import Any, Callable, Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, named_apply, MATCH_PREV_GROUP
from .layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, get_attn, create_act_layer, make_divisible
from .registry import register_model
from timm.layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, get_attn, create_act_layer, make_divisible
from ._builder import build_model_with_cfg
from ._manipulate import named_apply, MATCH_PREV_GROUP
from ._registry import register_model
__all__ = ['CspNet'] # model_registry will add each entrypoint fn to this

@ -17,9 +17,11 @@ from torch import nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.vision_transformer import VisionTransformer, trunc_normal_, checkpoint_filter_fn
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import register_model
from .helpers import build_model_with_cfg, checkpoint_seq
from .registry import register_model
__all__ = ['VisionTransformerDistilled'] # model_registry will add each entrypoint fn to this
def _cfg(url='', **kwargs):

@ -4,7 +4,6 @@ fixed kwargs passthrough and addition of dynamic global avg/max pool.
"""
import re
from collections import OrderedDict
from functools import partial
import torch
import torch.nn as nn
@ -13,9 +12,10 @@ import torch.utils.checkpoint as cp
from torch.jit.annotations import List
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, MATCH_PREV_GROUP
from .layers import BatchNormAct2d, create_norm_act_layer, BlurPool2d, create_classifier
from .registry import register_model
from timm.layers import BatchNormAct2d, create_norm_act_layer, BlurPool2d, create_classifier
from ._builder import build_model_with_cfg
from ._manipulate import MATCH_PREV_GROUP
from ._registry import register_model
__all__ = ['DenseNet']

@ -13,9 +13,9 @@ import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import create_classifier
from .registry import register_model
from timm.layers import create_classifier
from ._builder import build_model_with_cfg
from ._registry import register_model
__all__ = ['DLA']

@ -15,9 +15,9 @@ import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import BatchNormAct2d, ConvNormAct, create_conv2d, create_classifier
from .registry import register_model
from timm.layers import BatchNormAct2d, ConvNormAct, create_conv2d, create_classifier
from ._builder import build_model_with_cfg
from ._registry import register_model
__all__ = ['DPN']

@ -8,20 +8,20 @@ Original code and weights from https://github.com/mmaaz60/EdgeNeXt
Modifications and additions for timm by / Copyright 2022, Ross Wightman
"""
import math
import torch
from collections import OrderedDict
from functools import partial
from typing import Tuple
from torch import nn
import torch
import torch.nn.functional as F
from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_module
from .layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d
from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
from .registry import register_model
from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_module
from ._manipulate import named_apply, checkpoint_seq
from ._registry import register_model
__all__ = ['EdgeNeXt'] # model_registry will add each entrypoint fn to this

@ -18,9 +18,11 @@ import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import DropPath, trunc_normal_, to_2tuple, Mlp
from .registry import register_model
from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp
from ._builder import build_model_with_cfg
from ._registry import register_model
__all__ = ['EfficientFormer'] # model_registry will add each entrypoint fn to this
def _cfg(url='', **kwargs):

@ -42,15 +42,15 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .efficientnet_blocks import SqueezeExcite
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights,\
from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, GroupNormAct
from ._builder import build_model_with_cfg, pretrained_cfg_for_features
from ._efficientnet_blocks import SqueezeExcite
from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from .features import FeatureInfo, FeatureHooks
from .helpers import build_model_with_cfg, pretrained_cfg_for_features, checkpoint_seq
from .layers import create_conv2d, create_classifier, get_norm_act_layer, EvoNorm2dS0, GroupNormAct
from .registry import register_model
from ._features import FeatureInfo, FeatureHooks
from ._manipulate import checkpoint_seq
from ._registry import register_model
__all__ = ['EfficientNet', 'EfficientNetFeatures']

@ -28,12 +28,13 @@ import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg, named_apply
from .layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d,\
from timm.layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d, \
get_attn, get_act_layer, get_norm_layer, _assert
from .registry import register_model
from .vision_transformer_relpos import RelPosMlp, RelPosBias # FIXME move to common location
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._manipulate import named_apply
from ._registry import register_model
from .vision_transformer_relpos import RelPosBias # FIXME move to common location
__all__ = ['GlobalContextVit']

@ -11,13 +11,12 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .layers import SelectAdaptivePool2d, Linear, make_divisible
from .efficientnet_blocks import SqueezeExcite, ConvBnAct
from .helpers import build_model_with_cfg, checkpoint_seq
from .registry import register_model
from timm.layers import SelectAdaptivePool2d, Linear, make_divisible
from ._builder import build_model_with_cfg
from ._efficientnet_blocks import SqueezeExcite, ConvBnAct
from ._manipulate import checkpoint_seq
from ._registry import register_model
__all__ = ['GhostNet']

@ -5,11 +5,13 @@ by Ross Wightman
"""
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import SEModule
from .registry import register_model
from timm.layers import SEModule
from ._builder import build_model_with_cfg
from ._registry import register_model
from .resnet import ResNet, Bottleneck, BasicBlock
__all__ = []
def _cfg(url='', **kwargs):
return {

@ -13,9 +13,9 @@ import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import create_classifier, get_padding
from .registry import register_model
from timm.layers import create_classifier, get_padding
from ._builder import build_model_with_cfg
from ._registry import register_model
__all__ = ['Xception65']

@ -3,12 +3,14 @@ from functools import partial
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .efficientnet_blocks import SqueezeExcite
from .efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args, round_channels
from .helpers import build_model_with_cfg, pretrained_cfg_for_features
from .layers import get_act_fn
from ._builder import build_model_with_cfg
from ._builder import pretrained_cfg_for_features
from ._efficientnet_blocks import SqueezeExcite
from ._efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args, round_channels
from ._registry import register_model
from .mobilenetv3 import MobileNetV3, MobileNetV3Features
from .registry import register_model
__all__ = [] # model_registry will add each entrypoint fn to this
def _cfg(url='', **kwargs):

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save