From eb7653614f438d1eeae259262fade32d230a5be4 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 1 Jun 2020 16:59:51 -0700 Subject: [PATCH] Monster commit, activation refactor, VoVNet, norm_act improvements, more * refactor activations into basic PyTorch, jit scripted, and memory efficient custom auto * implement hard-mish, better grad for hard-swish * add initial VovNet V1/V2 impl, fix #151 * VovNet and DenseNet first models to use NormAct layers (support BatchNormAct2d, EvoNorm, InplaceIABN) * Wrap IABN for any models that use it * make more models torchscript compatible (DPN, PNasNet, Res2Net, SelecSLS) and add tests --- tests/test_models.py | 26 +- timm/__init__.py | 3 +- timm/models/__init__.py | 2 + timm/models/densenet.py | 36 +-- timm/models/dpn.py | 64 +++- timm/models/efficientnet.py | 18 +- timm/models/efficientnet_blocks.py | 13 +- timm/models/efficientnet_builder.py | 16 +- timm/models/layers/__init__.py | 39 +-- timm/models/layers/activations.py | 109 ++----- timm/models/layers/activations_jit.py | 90 ++++++ timm/models/layers/activations_me.py | 208 +++++++++++++ timm/models/layers/cond_conv2d.py | 2 +- timm/models/layers/config.py | 74 +++++ timm/models/layers/conv2d_same.py | 3 +- timm/models/layers/conv_bn_act.py | 27 +- timm/models/layers/create_act.py | 103 +++++++ timm/models/layers/create_attn.py | 4 +- timm/models/layers/create_conv2d.py | 12 +- timm/models/layers/create_norm_act.py | 77 +++-- timm/models/layers/drop.py | 2 - timm/models/layers/evo_norm.py | 40 ++- timm/models/layers/inplace_abn.py | 85 ++++++ timm/models/layers/norm_act.py | 69 +++-- timm/models/layers/pool2d_same.py | 1 - timm/models/layers/se.py | 23 +- timm/models/layers/selective_kernel.py | 1 - timm/models/layers/separable_conv.py | 51 ++++ timm/models/layers/test_time_pool.py | 1 + timm/models/mobilenetv3.py | 20 +- timm/models/pnasnet.py | 24 +- timm/models/res2net.py | 14 +- timm/models/resnet.py | 1 - timm/models/selecsls.py | 26 +- timm/models/tresnet.py | 81 +++-- timm/models/vovnet.py | 408 +++++++++++++++++++++++++ validate.py | 10 +- 37 files changed, 1467 insertions(+), 316 deletions(-) create mode 100644 timm/models/layers/activations_jit.py create mode 100644 timm/models/layers/activations_me.py create mode 100644 timm/models/layers/config.py create mode 100644 timm/models/layers/create_act.py create mode 100644 timm/models/layers/inplace_abn.py create mode 100644 timm/models/layers/separable_conv.py create mode 100644 timm/models/vovnet.py diff --git a/tests/test_models.py b/tests/test_models.py index 02cb61bb..63be6a6e 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -4,7 +4,7 @@ import platform import os import fnmatch -from timm import list_models, create_model +from timm import list_models, create_model, set_scriptable if 'GITHUB_ACTIONS' in os.environ and 'Linux' in platform.system(): @@ -53,6 +53,8 @@ def test_model_backward(model_name, batch_size): inputs = torch.randn((batch_size, *input_size)) outputs = model(inputs) outputs.mean().backward() + for n, x in model.named_parameters(): + assert x.grad is not None, f'No gradient for {n}' num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None]) assert outputs.shape[-1] == 42 @@ -83,3 +85,25 @@ def test_model_default_cfgs(model_name, batch_size): assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2] assert any([k.startswith(classifier) for k in state_dict.keys()]), f'{classifier} not in model params' assert any([k.startswith(first_conv) for k in 
state_dict.keys()]), f'{first_conv} not in model params' + + +EXCLUDE_JIT_FILTERS = [ + '*iabn*', 'tresnet*', # models using inplace abn unlikely to ever be scriptable + 'dla*', 'hrnet*', # hopefully fix at some point +] + + +@pytest.mark.timeout(120) +@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS)) +@pytest.mark.parametrize('batch_size', [1]) +def test_model_forward_torchscript(model_name, batch_size): + """Run a single forward pass with each model""" + with set_scriptable(True): + model = create_model(model_name, pretrained=False) + model.eval() + input_size = (3, 128, 128) # jit compile is already a bit slow and we've tested normal res already... + model = torch.jit.script(model) + outputs = model(torch.randn((batch_size, *input_size))) + + assert outputs.shape[0] == batch_size + assert not torch.isnan(outputs).any(), 'Output included NaNs' diff --git a/timm/__init__.py b/timm/__init__.py index 86ed7a42..db3d3f22 100644 --- a/timm/__init__.py +++ b/timm/__init__.py @@ -1,2 +1,3 @@ from .version import __version__ -from .models import create_model, list_models, is_model, list_modules, model_entrypoint +from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \ + is_scriptable, is_exportable, set_scriptable, set_exportable diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 06d26fb3..b4fe1dea 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -20,9 +20,11 @@ from .sknet import * from .tresnet import * from .resnest import * from .regnet import * +from .vovnet import * from .registry import * from .factory import create_model from .helpers import load_checkpoint, resume_checkpoint from .layers import TestTimePoolHead, apply_test_time_pool from .layers import convert_splitbn_model +from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit diff --git a/timm/models/densenet.py b/timm/models/densenet.py index 539d5012..b4e31807 100644 --- a/timm/models/densenet.py +++ b/timm/models/densenet.py @@ -41,13 +41,13 @@ default_cfgs = { class DenseLayer(nn.Module): - def __init__(self, num_input_features, growth_rate, bn_size, norm_act_layer=BatchNormAct2d, + def __init__(self, num_input_features, growth_rate, bn_size, norm_layer=BatchNormAct2d, drop_rate=0., memory_efficient=False): super(DenseLayer, self).__init__() - self.add_module('norm1', norm_act_layer(num_input_features)), + self.add_module('norm1', norm_layer(num_input_features)), self.add_module('conv1', nn.Conv2d( num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)), - self.add_module('norm2', norm_act_layer(bn_size * growth_rate)), + self.add_module('norm2', norm_layer(bn_size * growth_rate)), self.add_module('conv2', nn.Conv2d( bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)), self.drop_rate = float(drop_rate) @@ -109,7 +109,7 @@ class DenseLayer(nn.Module): class DenseBlock(nn.ModuleDict): _version = 2 - def __init__(self, num_layers, num_input_features, bn_size, growth_rate, norm_act_layer=nn.ReLU, + def __init__(self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=nn.ReLU, drop_rate=0., memory_efficient=False): super(DenseBlock, self).__init__() for i in range(num_layers): @@ -117,7 +117,7 @@ class DenseBlock(nn.ModuleDict): num_input_features + i * growth_rate, growth_rate=growth_rate, bn_size=bn_size, - norm_act_layer=norm_act_layer, + norm_layer=norm_layer, drop_rate=drop_rate, 
memory_efficient=memory_efficient, ) @@ -132,9 +132,9 @@ class DenseBlock(nn.ModuleDict): class DenseTransition(nn.Sequential): - def __init__(self, num_input_features, num_output_features, norm_act_layer=nn.BatchNorm2d, aa_layer=None): + def __init__(self, num_input_features, num_output_features, norm_layer=nn.BatchNorm2d, aa_layer=None): super(DenseTransition, self).__init__() - self.add_module('norm', norm_act_layer(num_input_features)) + self.add_module('norm', norm_layer(num_input_features)) self.add_module('conv', nn.Conv2d( num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)) if aa_layer is not None: @@ -160,7 +160,7 @@ class DenseNet(nn.Module): def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), bn_size=4, stem_type='', num_classes=1000, in_chans=3, global_pool='avg', - norm_act_layer=BatchNormAct2d, aa_layer=None, drop_rate=0, memory_efficient=False): + norm_layer=BatchNormAct2d, aa_layer=None, drop_rate=0, memory_efficient=False): self.num_classes = num_classes self.drop_rate = drop_rate super(DenseNet, self).__init__() @@ -181,17 +181,17 @@ class DenseNet(nn.Module): stem_chs_2 = num_init_features if 'narrow' in stem_type else 6 * (growth_rate // 4) self.features = nn.Sequential(OrderedDict([ ('conv0', nn.Conv2d(in_chans, stem_chs_1, 3, stride=2, padding=1, bias=False)), - ('norm0', norm_act_layer(stem_chs_1)), + ('norm0', norm_layer(stem_chs_1)), ('conv1', nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False)), - ('norm1', norm_act_layer(stem_chs_2)), + ('norm1', norm_layer(stem_chs_2)), ('conv2', nn.Conv2d(stem_chs_2, num_init_features, 3, stride=1, padding=1, bias=False)), - ('norm2', norm_act_layer(num_init_features)), + ('norm2', norm_layer(num_init_features)), ('pool0', stem_pool), ])) else: self.features = nn.Sequential(OrderedDict([ ('conv0', nn.Conv2d(in_chans, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), - ('norm0', norm_act_layer(num_init_features)), + ('norm0', norm_layer(num_init_features)), ('pool0', stem_pool), ])) @@ -203,7 +203,7 @@ class DenseNet(nn.Module): num_input_features=num_features, bn_size=bn_size, growth_rate=growth_rate, - norm_act_layer=norm_act_layer, + norm_layer=norm_layer, drop_rate=drop_rate, memory_efficient=memory_efficient ) @@ -212,12 +212,12 @@ class DenseNet(nn.Module): if i != len(block_config) - 1: trans = DenseTransition( num_input_features=num_features, num_output_features=num_features // 2, - norm_act_layer=norm_act_layer) + norm_layer=norm_layer) self.features.add_module('transition%d' % (i + 1), trans) num_features = num_features // 2 # Final batch norm - self.features.add_module('norm5', norm_act_layer(num_features)) + self.features.add_module('norm5', norm_layer(num_features)) # Linear layer self.num_features = num_features @@ -346,7 +346,7 @@ def densenet121d_evob(pretrained=False, **kwargs): return create_norm_act('EvoNormBatch', num_features, jit=True, **kwargs) model = _densenet( 'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep', - norm_act_layer=norm_act_fn, pretrained=pretrained, **kwargs) + norm_layer=norm_act_fn, pretrained=pretrained, **kwargs) return model @@ -359,7 +359,7 @@ def densenet121d_evos(pretrained=False, **kwargs): return create_norm_act('EvoNormSample', num_features, jit=True, **kwargs) model = _densenet( 'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep', - norm_act_layer=norm_act_fn, pretrained=pretrained, **kwargs) + norm_layer=norm_act_fn, pretrained=pretrained, 
**kwargs) return model @@ -372,7 +372,7 @@ def densenet121d_iabn(pretrained=False, **kwargs): return create_norm_act('iabn', num_features, **kwargs) model = _densenet( 'densenet121tn', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep', - norm_act_layer=norm_act_fn, pretrained=pretrained, **kwargs) + norm_layer=norm_act_fn, pretrained=pretrained, **kwargs) return model diff --git a/timm/models/dpn.py b/timm/models/dpn.py index 9c4fafc8..1f45095d 100644 --- a/timm/models/dpn.py +++ b/timm/models/dpn.py @@ -10,6 +10,7 @@ from __future__ import division from __future__ import print_function from collections import OrderedDict +from typing import Union, Optional, List, Tuple import torch import torch.nn as nn @@ -54,8 +55,19 @@ class CatBnAct(nn.Module): self.bn = nn.BatchNorm2d(in_chs, eps=0.001) self.act = activation_fn + @torch.jit._overload_method # noqa: F811 def forward(self, x): - x = torch.cat(x, dim=1) if isinstance(x, tuple) else x + # type: (Tuple[torch.Tensor, torch.Tensor]) -> (torch.Tensor) + pass + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (torch.Tensor) -> (torch.Tensor) + pass + + def forward(self, x): + if isinstance(x, tuple): + x = torch.cat(x, dim=1) return self.act(self.bn(x)) @@ -107,6 +119,8 @@ class DualPathBlock(nn.Module): self.key_stride = 1 self.has_proj = False + self.c1x1_w_s1 = None + self.c1x1_w_s2 = None if self.has_proj: # Using different member names here to allow easier parameter key matching for conversion if self.key_stride == 2: @@ -115,6 +129,7 @@ class DualPathBlock(nn.Module): else: self.c1x1_w_s1 = BnActConv2d( in_chs=in_chs, out_chs=num_1x1_c + 2 * inc, kernel_size=1, stride=1) + self.c1x1_a = BnActConv2d(in_chs=in_chs, out_chs=num_1x1_a, kernel_size=1, stride=1) self.c3x3_b = BnActConv2d( in_chs=num_1x1_a, out_chs=num_3x3_b, kernel_size=3, @@ -125,27 +140,46 @@ class DualPathBlock(nn.Module): self.c1x1_c2 = nn.Conv2d(num_3x3_b, inc, kernel_size=1, bias=False) else: self.c1x1_c = BnActConv2d(in_chs=num_3x3_b, out_chs=num_1x1_c + inc, kernel_size=1, stride=1) + self.c1x1_c1 = None + self.c1x1_c2 = None + @torch.jit._overload_method # noqa: F811 def forward(self, x): - x_in = torch.cat(x, dim=1) if isinstance(x, tuple) else x - if self.has_proj: - if self.key_stride == 2: - x_s = self.c1x1_w_s2(x_in) - else: - x_s = self.c1x1_w_s1(x_in) - x_s1 = x_s[:, :self.num_1x1_c, :, :] - x_s2 = x_s[:, self.num_1x1_c:, :, :] + # type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor] + pass + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] + pass + + def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]: + if isinstance(x, tuple): + x_in = torch.cat(x, dim=1) else: + x_in = x + if self.c1x1_w_s1 is None and self.c1x1_w_s2 is None: + # self.has_proj == False, torchscript requires condition on module == None x_s1 = x[0] x_s2 = x[1] + else: + # self.has_proj == True + if self.c1x1_w_s1 is not None: + # self.key_stride = 1 + x_s = self.c1x1_w_s1(x_in) + else: + # self.key_stride = 2 + x_s = self.c1x1_w_s2(x_in) + x_s1 = x_s[:, :self.num_1x1_c, :, :] + x_s2 = x_s[:, self.num_1x1_c:, :, :] x_in = self.c1x1_a(x_in) x_in = self.c3x3_b(x_in) - if self.b: - x_in = self.c1x1_c(x_in) + x_in = self.c1x1_c(x_in) + if self.c1x1_c1 is not None: + # self.b == True, using None check for torchscript compat out1 = self.c1x1_c1(x_in) out2 = self.c1x1_c2(x_in) else: - x_in = self.c1x1_c(x_in) out1 = x_in[:, :self.num_1x1_c, :, :] 
out2 = x_in[:, self.num_1x1_c:, :, :] resid = x_s1 + out1 @@ -167,11 +201,9 @@ class DPN(nn.Module): # conv1 if small: - blocks['conv1_1'] = InputBlock( - num_init_features, in_chans=in_chans, kernel_size=3, padding=1) + blocks['conv1_1'] = InputBlock(num_init_features, in_chans=in_chans, kernel_size=3, padding=1) else: - blocks['conv1_1'] = InputBlock( - num_init_features, in_chans=in_chans, kernel_size=7, padding=3) + blocks['conv1_1'] = InputBlock(num_init_features, in_chans=in_chans, kernel_size=7, padding=3) # conv2 bw = 64 * bw_factor diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 21fbee19..fbd7f420 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -24,11 +24,15 @@ An implementation of EfficienNet that covers variety of related models with effi Hacked together by Ross Wightman """ +import torch.nn as nn +import torch.nn.functional as F + from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .efficientnet_builder import * +from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT +from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights from .feature_hooks import FeatureHooks from .helpers import load_pretrained, adapt_model_from_file -from .layers import SelectAdaptivePool2d +from .layers import SelectAdaptivePool2d, create_conv2d from .registry import register_model __all__ = ['EfficientNet'] @@ -631,7 +635,7 @@ def _gen_mobilenet_v2( fix_stem=fix_stem_head, channel_multiplier=channel_multiplier, norm_kwargs=resolve_bn_args(kwargs), - act_layer=nn.ReLU6, + act_layer=resolve_act_layer(kwargs, 'relu6'), **kwargs ) model = _create_model(model_kwargs, default_cfgs[variant], pretrained) @@ -741,7 +745,7 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre num_features=round_channels(1280, channel_multiplier, 8, None), stem_size=32, channel_multiplier=channel_multiplier, - act_layer=Swish, + act_layer=resolve_act_layer(kwargs, 'swish'), norm_kwargs=resolve_bn_args(kwargs), variant=variant, **kwargs, @@ -772,7 +776,7 @@ def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0 stem_size=32, channel_multiplier=channel_multiplier, norm_kwargs=resolve_bn_args(kwargs), - act_layer=nn.ReLU, + act_layer=resolve_act_layer(kwargs, 'relu'), **kwargs, ) model = _create_model(model_kwargs, default_cfgs[variant], pretrained) @@ -802,7 +806,7 @@ def _gen_efficientnet_condconv( stem_size=32, channel_multiplier=channel_multiplier, norm_kwargs=resolve_bn_args(kwargs), - act_layer=Swish, + act_layer=resolve_act_layer(kwargs, 'swish'), **kwargs, ) model = _create_model(model_kwargs, default_cfgs[variant], pretrained) @@ -842,7 +846,7 @@ def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0 stem_size=32, fix_stem=True, channel_multiplier=channel_multiplier, - act_layer=nn.ReLU6, + act_layer=resolve_act_layer(kwargs, 'relu6'), norm_kwargs=resolve_bn_args(kwargs), **kwargs, ) diff --git a/timm/models/efficientnet_blocks.py b/timm/models/efficientnet_blocks.py index cc4cdef1..5f64dc37 100644 --- a/timm/models/efficientnet_blocks.py +++ b/timm/models/efficientnet_blocks.py @@ -1,9 +1,9 @@ import torch import torch.nn as nn from torch.nn import functional as F -from .layers.activations import sigmoid -from .layers import create_conv2d, drop_path +from .layers import create_conv2d, drop_path, get_act_layer +from 
.layers.activations import sigmoid # Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per # papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay) @@ -52,6 +52,13 @@ def resolve_se_args(kwargs, in_chs, act_layer=None): return se_kwargs +def resolve_act_layer(kwargs, default='relu'): + act_layer = kwargs.pop('act_layer', default) + if isinstance(act_layer, str): + act_layer = get_act_layer(act_layer) + return act_layer + + def make_divisible(v, divisor=8, min_value=None): min_value = min_value or divisor new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) @@ -213,7 +220,7 @@ class InvertedResidual(nn.Module): has_se = se_ratio is not None and se_ratio > 0. self.has_residual = (in_chs == out_chs and stride == 1) and not noskip self.drop_path_rate = drop_path_rate - + print(act_layer) # Point-wise expansion self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs) self.bn1 = norm_layer(mid_chs, **norm_kwargs) diff --git a/timm/models/efficientnet_builder.py b/timm/models/efficientnet_builder.py index 842098cf..1e06b4f3 100644 --- a/timm/models/efficientnet_builder.py +++ b/timm/models/efficientnet_builder.py @@ -1,13 +1,15 @@ import logging import math import re -from collections.__init__ import OrderedDict +from collections import OrderedDict from copy import deepcopy import torch.nn as nn -from .layers import CondConv2d, get_condconv_initializer -from .layers.activations import HardSwish, Swish + from .efficientnet_blocks import * +from .layers import CondConv2d, get_condconv_initializer + +__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights"] def _parse_ksize(ss): @@ -57,13 +59,13 @@ def _decode_block_str(block_str): key = op[0] v = op[1:] if v == 're': - value = nn.ReLU + value = get_act_layer('relu') elif v == 'r6': - value = nn.ReLU6 + value = get_act_layer('relu6') elif v == 'hs': - value = HardSwish + value = get_act_layer('hard_swish') elif v == 'sw': - value = Swish + value = get_act_layer('swish') else: continue options[key] = value diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index e007a46d..b9c26fea 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -1,25 +1,28 @@ -from .padding import get_padding -from .pool2d_same import AvgPool2dSame -from .conv2d_same import Conv2dSame -from .conv_bn_act import ConvBnAct -from .mixed_conv2d import MixedConv2d -from .cond_conv2d import CondConv2d, get_condconv_initializer -from .pool2d_same import create_pool2d -from .create_conv2d import create_conv2d -from .create_attn import create_attn -from .selective_kernel import SelectiveKernelConv -from .se import SEModule -from .eca import EcaModule, CecaModule from .activations import * from .adaptive_avgmax_pool import \ adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d -from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path -from .test_time_pool import TestTimePoolHead, apply_test_time_pool -from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model from .anti_aliasing import AntiAliasDownsampleLayer -from .space_to_depth import SpaceToDepthModule from .blur_pool import BlurPool2d -from .norm_act import BatchNormAct2d +from .cond_conv2d import CondConv2d, get_condconv_initializer +from .config import is_exportable, is_scriptable, set_exportable, set_scriptable, is_no_jit, set_no_jit +from .conv2d_same import Conv2dSame +from 
.conv_bn_act import ConvBnAct +from .create_act import create_act_layer, get_act_layer, get_act_fn +from .create_attn import create_attn +from .create_conv2d import create_conv2d +from .create_norm_act import create_norm_act, get_norm_act_layer +from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path +from .eca import EcaModule, CecaModule from .evo_norm import EvoNormBatch2d, EvoNormSample2d -from .create_norm_act import create_norm_act +from .inplace_abn import InplaceAbn +from .mixed_conv2d import MixedConv2d +from .norm_act import BatchNormAct2d +from .padding import get_padding +from .pool2d_same import AvgPool2dSame, create_pool2d +from .se import SEModule +from .selective_kernel import SelectiveKernelConv +from .separable_conv import SeparableConv2d, SeparableConvBnAct +from .space_to_depth import SpaceToDepthModule +from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model +from .test_time_pool import TestTimePoolHead, apply_test_time_pool from .weight_init import trunc_normal_ diff --git a/timm/models/layers/activations.py b/timm/models/layers/activations.py index 6f8d2f89..71904935 100644 --- a/timm/models/layers/activations.py +++ b/timm/models/layers/activations.py @@ -6,85 +6,15 @@ easily be swapped. All have an `inplace` arg even if not used. Hacked together by Ross Wightman """ - import torch from torch import nn as nn from torch.nn import functional as F -_USE_MEM_EFFICIENT_ISH = True -if _USE_MEM_EFFICIENT_ISH: - # This version reduces memory overhead of Swish during training by - # recomputing torch.sigmoid(x) in backward instead of saving it. - @torch.jit.script - def swish_jit_fwd(x): - return x.mul(torch.sigmoid(x)) - - - @torch.jit.script - def swish_jit_bwd(x, grad_output): - x_sigmoid = torch.sigmoid(x) - return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid))) - - - class SwishJitAutoFn(torch.autograd.Function): - """ torch.jit.script optimised Swish - Inspired by conversation btw Jeremy Howard & Adam Pazske - https://twitter.com/jeremyphoward/status/1188251041835315200 - """ - - @staticmethod - def forward(ctx, x): - ctx.save_for_backward(x) - return swish_jit_fwd(x) - - @staticmethod - def backward(ctx, grad_output): - x = ctx.saved_tensors[0] - return swish_jit_bwd(x, grad_output) - - - def swish(x, _inplace=False): - return SwishJitAutoFn.apply(x) - - - @torch.jit.script - def mish_jit_fwd(x): - return x.mul(torch.tanh(F.softplus(x))) - - - @torch.jit.script - def mish_jit_bwd(x, grad_output): - x_sigmoid = torch.sigmoid(x) - x_tanh_sp = F.softplus(x).tanh() - return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp)) - - - class MishJitAutoFn(torch.autograd.Function): - @staticmethod - def forward(ctx, x): - ctx.save_for_backward(x) - return mish_jit_fwd(x) - - @staticmethod - def backward(ctx, grad_output): - x = ctx.saved_tensors[0] - return mish_jit_bwd(x, grad_output) - - def mish(x, _inplace=False): - return MishJitAutoFn.apply(x) - -else: - def swish(x, inplace: bool = False): - """Swish - Described in: https://arxiv.org/abs/1710.05941 - """ - return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) - - - def mish(x, _inplace: bool = False): - """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 - """ - return x.mul(F.softplus(x).tanh()) +def swish(x, inplace: bool = False): + """Swish - Described in: https://arxiv.org/abs/1710.05941 + """ + return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) class Swish(nn.Module): @@ -96,13 +26,21 @@ class 
Swish(nn.Module): return swish(x, self.inplace) +def mish(x, inplace: bool = False): + """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + NOTE: I don't have a working inplace variant + """ + return x.mul(F.softplus(x).tanh()) + + class Mish(nn.Module): + """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + """ def __init__(self, inplace: bool = False): super(Mish, self).__init__() - self.inplace = inplace def forward(self, x): - return mish(x, self.inplace) + return mish(x) def sigmoid(x, inplace: bool = False): @@ -162,3 +100,22 @@ class HardSigmoid(nn.Module): def forward(self, x): return hard_sigmoid(x, self.inplace) + +def hard_mish(x, inplace: bool = False): + """ Hard Mish + Experimental, based on notes by Mish author Diganta Misra at + https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md + """ + if inplace: + return x.mul_(0.5 * (x + 2).clamp(min=0, max=2)) + else: + return 0.5 * x * (x + 2).clamp(min=0, max=2) + + +class HardMish(nn.Module): + def __init__(self, inplace: bool = False): + super(HardMish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return hard_mish(x, self.inplace) diff --git a/timm/models/layers/activations_jit.py b/timm/models/layers/activations_jit.py new file mode 100644 index 00000000..dd3277fa --- /dev/null +++ b/timm/models/layers/activations_jit.py @@ -0,0 +1,90 @@ +""" Activations + +A collection of jit-scripted activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not +currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted +versions if they contain in-place ops. + +Hacked together by Ross Wightman +""" + +import torch +from torch import nn as nn +from torch.nn import functional as F + + +@torch.jit.script +def swish_jit(x, inplace: bool = False): + """Swish - Described in: https://arxiv.org/abs/1710.05941 + """ + return x.mul(x.sigmoid()) + + +@torch.jit.script +def mish_jit(x, _inplace: bool = False): + """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + """ + return x.mul(F.softplus(x).tanh()) + + +class SwishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(SwishJit, self).__init__() + + def forward(self, x): + return swish_jit(x) + + +class MishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(MishJit, self).__init__() + + def forward(self, x): + return mish_jit(x) + + +@torch.jit.script +def hard_sigmoid_jit(x, inplace: bool = False): + # return F.relu6(x + 3.) / 6. + return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? + + +class HardSigmoidJit(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSigmoidJit, self).__init__() + + def forward(self, x): + return hard_sigmoid_jit(x) + + +@torch.jit.script +def hard_swish_jit(x, inplace: bool = False): + # return x * (F.relu6(x + 3.) / 6) + return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 
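For reference, the clamp-based expressions above are algebraically the same as the commented F.relu6 forms; a quick equivalence check (illustrative only, not part of this patch, assuming a checkout with it applied) would be:

import torch
import torch.nn.functional as F
from timm.models.layers.activations_jit import hard_sigmoid_jit, hard_swish_jit

x = torch.linspace(-6, 6, steps=121)
# relu6(x + 3) / 6 and x * relu6(x + 3) / 6 are the reference forms noted in the comments above
assert torch.allclose(hard_sigmoid_jit(x), F.relu6(x + 3.) / 6.)
assert torch.allclose(hard_swish_jit(x), x * F.relu6(x + 3.) / 6.)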
+ + +class HardSwishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSwishJit, self).__init__() + + def forward(self, x): + return hard_swish_jit(x) + + +@torch.jit.script +def hard_mish_jit(x, inplace: bool = False): + """ Hard Mish + Experimental, based on notes by Mish author Diganta Misra at + https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md + """ + return 0.5 * x * (x + 2).clamp(min=0, max=2) + + +class HardMishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(HardMishJit, self).__init__() + + def forward(self, x): + return hard_mish_jit(x) diff --git a/timm/models/layers/activations_me.py b/timm/models/layers/activations_me.py new file mode 100644 index 00000000..9c492f1e --- /dev/null +++ b/timm/models/layers/activations_me.py @@ -0,0 +1,208 @@ +""" Activations (memory-efficient w/ custom autograd) + +A collection of activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +These activations are not compatible with jit scripting or ONNX export of the model, please use either +the JIT or basic versions of the activations. + +Hacked together by Ross Wightman +""" + +import torch +from torch import nn as nn +from torch.nn import functional as F + + +@torch.jit.script +def swish_jit_fwd(x): + return x.mul(torch.sigmoid(x)) + + +@torch.jit.script +def swish_jit_bwd(x, grad_output): + x_sigmoid = torch.sigmoid(x) + return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid))) + + +class SwishJitAutoFn(torch.autograd.Function): + """ torch.jit.script optimised Swish w/ memory-efficient checkpoint + Inspired by conversation btw Jeremy Howard & Adam Pazske + https://twitter.com/jeremyphoward/status/1188251041835315200 + """ + + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return swish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return swish_jit_bwd(x, grad_output) + + +def swish_me(x, inplace=False): + return SwishJitAutoFn.apply(x) + + +class SwishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(SwishMe, self).__init__() + + def forward(self, x): + return SwishJitAutoFn.apply(x) + + +@torch.jit.script +def mish_jit_fwd(x): + return x.mul(torch.tanh(F.softplus(x))) + + +@torch.jit.script +def mish_jit_bwd(x, grad_output): + x_sigmoid = torch.sigmoid(x) + x_tanh_sp = F.softplus(x).tanh() + return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp)) + + +class MishJitAutoFn(torch.autograd.Function): + """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + A memory efficient, jit scripted variant of Mish + """ + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return mish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return mish_jit_bwd(x, grad_output) + + +def mish_me(x, inplace=False): + return MishJitAutoFn.apply(x) + + +class MishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(MishMe, self).__init__() + + def forward(self, x): + return MishJitAutoFn.apply(x) + + +@torch.jit.script +def hard_sigmoid_jit_fwd(x, inplace: bool = False): + return (x + 3).clamp(min=0, max=6).div(6.) + + +@torch.jit.script +def hard_sigmoid_jit_bwd(x, grad_output): + m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6. 
+ return grad_output * m + + +class HardSigmoidJitAutoFn(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return hard_sigmoid_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return hard_sigmoid_jit_bwd(x, grad_output) + + +def hard_sigmoid_me(x, inplace: bool = False): + return HardSigmoidJitAutoFn.apply(x) + + +class HardSigmoidMe(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSigmoidMe, self).__init__() + + def forward(self, x): + return HardSigmoidJitAutoFn.apply(x) + + +@torch.jit.script +def hard_swish_jit_fwd(x): + return x * (x + 3).clamp(min=0, max=6).div(6.) + + +@torch.jit.script +def hard_swish_jit_bwd(x, grad_output): + m = torch.ones_like(x) * (x >= 3.) + m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m) + return grad_output * m + + +class HardSwishJitAutoFn(torch.autograd.Function): + """A memory efficient, jit-scripted HardSwish activation""" + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return hard_swish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return hard_swish_jit_bwd(x, grad_output) + + +def hard_swish_me(x, inplace=False): + return HardSwishJitAutoFn.apply(x) + + +class HardSwishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSwishMe, self).__init__() + + def forward(self, x): + return HardSwishJitAutoFn.apply(x) + + +@torch.jit.script +def hard_mish_jit_fwd(x): + return 0.5 * x * (x + 2).clamp(min=0, max=2) + + +@torch.jit.script +def hard_mish_jit_bwd(x, grad_output): + m = torch.ones_like(x) * (x >= -2.) + m = torch.where((x >= -2.) & (x <= 0.), x + 1., m) + return grad_output * m + + +class HardMishJitAutoFn(torch.autograd.Function): + """ A memory efficient, jit scripted variant of Hard Mish + Experimental, based on notes by Mish author Diganta Misra at + https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md + """ + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return hard_mish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return hard_mish_jit_bwd(x, grad_output) + + +def hard_mish_me(x, inplace: bool = False): + return HardMishJitAutoFn.apply(x) + + +class HardMishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(HardMishMe, self).__init__() + + def forward(self, x): + return HardMishJitAutoFn.apply(x) + + + diff --git a/timm/models/layers/cond_conv2d.py b/timm/models/layers/cond_conv2d.py index 0241b501..b1759d99 100644 --- a/timm/models/layers/cond_conv2d.py +++ b/timm/models/layers/cond_conv2d.py @@ -15,7 +15,7 @@ from torch.nn import functional as F from .helpers import tup_pair from .conv2d_same import conv2d_same -from timm.models.layers.padding import get_padding_value +from .padding import get_padding_value def get_condconv_initializer(initializer, num_experts, expert_shape): diff --git a/timm/models/layers/config.py b/timm/models/layers/config.py new file mode 100644 index 00000000..2c0faf23 --- /dev/null +++ b/timm/models/layers/config.py @@ -0,0 +1,74 @@ +""" Model / Layer Config Singleton +""" +from typing import Any + +__all__ = ['is_exportable', 'is_scriptable', 'set_exportable', 'set_scriptable', 'is_no_jit', 'set_no_jit'] + +# Set to True if prefer to have layers with no jit optimization (includes activations) +_NO_JIT = False + +# Set to True if prefer to have activation layers with no jit optimization +_NO_ACTIVATION_JIT 
= False + +# Set to True if exporting a model with Same padding via ONNX +_EXPORTABLE = False + +# Set to True if wanting to use torch.jit.script on a model +_SCRIPTABLE = False + + +def is_no_jit(): + return _NO_JIT + + +class set_no_jit: + def __init__(self, mode: bool) -> None: + global _NO_JIT + self.prev = _NO_JIT + _NO_JIT = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _NO_JIT + _NO_JIT = self.prev + return False + + +def is_exportable(): + return _EXPORTABLE + + +class set_exportable: + def __init__(self, mode: bool) -> None: + global _EXPORTABLE + self.prev = _EXPORTABLE + _EXPORTABLE = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _EXPORTABLE + _EXPORTABLE = self.prev + return False + + +def is_scriptable(): + return _SCRIPTABLE + + +class set_scriptable: + def __init__(self, mode: bool) -> None: + global _SCRIPTABLE + self.prev = _SCRIPTABLE + _SCRIPTABLE = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _SCRIPTABLE + _SCRIPTABLE = self.prev + return False diff --git a/timm/models/layers/conv2d_same.py b/timm/models/layers/conv2d_same.py index 863d1783..06f08b4e 100644 --- a/timm/models/layers/conv2d_same.py +++ b/timm/models/layers/conv2d_same.py @@ -7,8 +7,7 @@ import torch.nn as nn import torch.nn.functional as F from typing import Tuple, Optional -from timm.models.layers.padding import get_padding_value -from .padding import pad_same +from .padding import pad_same, get_padding_value def conv2d_same( diff --git a/timm/models/layers/conv_bn_act.py b/timm/models/layers/conv_bn_act.py index d7835320..43f6760e 100644 --- a/timm/models/layers/conv_bn_act.py +++ b/timm/models/layers/conv_bn_act.py @@ -4,33 +4,28 @@ Hacked together by Ross Wightman """ from torch import nn as nn -from timm.models.layers import get_padding +from .create_conv2d import create_conv2d +from .create_norm_act import convert_norm_act_type class ConvBnAct(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, dilation=1, groups=1, - drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None): + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, + norm_layer=nn.BatchNorm2d, norm_kwargs=None, act_layer=nn.ReLU, apply_act=True, + drop_block=None, aa_layer=None): super(ConvBnAct, self).__init__() - padding = get_padding(kernel_size, stride, dilation) # assuming PyTorch style padding for this block use_aa = aa_layer is not None - self.conv = nn.Conv2d( - in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=1 if use_aa else stride, + self.conv = create_conv2d( + in_channels, out_channels, kernel_size, stride=1 if use_aa else stride, padding=padding, dilation=dilation, groups=groups, bias=False) - self.bn = norm_layer(out_channels) + + # NOTE for backwards compatibility with models that use separate norm and act layer definitions + norm_act_layer, norm_act_args = convert_norm_act_type(norm_layer, act_layer, norm_kwargs) + self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block, **norm_act_args) self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else None - self.drop_block = drop_block - if act_layer is not None: - self.act = act_layer(inplace=True) - else: - self.act = None def forward(self, x): x = self.conv(x) x = self.bn(x) - if self.drop_block is not None: - x = self.drop_block(x) - if self.act is 
not None: - x = self.act(x) if self.aa is not None: x = self.aa(x) return x diff --git a/timm/models/layers/create_act.py b/timm/models/layers/create_act.py new file mode 100644 index 00000000..66ab1e84 --- /dev/null +++ b/timm/models/layers/create_act.py @@ -0,0 +1,103 @@ +from .activations import * +from .activations_jit import * +from .activations_me import * +from .config import is_exportable, is_scriptable, is_no_jit + + +_ACT_FN_DEFAULT = dict( + swish=swish, + mish=mish, + relu=F.relu, + relu6=F.relu6, + sigmoid=sigmoid, + tanh=tanh, + hard_sigmoid=hard_sigmoid, + hard_swish=hard_swish, + hard_mish=hard_mish, +) + +_ACT_FN_JIT = dict( + swish=swish_jit, + mish=mish_jit, + hard_sigmoid=hard_sigmoid_jit, + hard_swish=hard_swish_jit, + hard_mish=hard_mish_jit +) + +_ACT_FN_ME = dict( + swish=swish_me, + mish=mish_me, + hard_sigmoid=hard_sigmoid_me, + hard_swish=hard_swish_me, + hard_mish=hard_mish_me, +) + +_ACT_LAYER_DEFAULT = dict( + swish=Swish, + mish=Mish, + relu=nn.ReLU, + relu6=nn.ReLU6, + sigmoid=Sigmoid, + tanh=Tanh, + hard_sigmoid=HardSigmoid, + hard_swish=HardSwish, + hard_mish=HardMish, +) + +_ACT_LAYER_JIT = dict( + swish=SwishJit, + mish=MishJit, + hard_sigmoid=HardSigmoidJit, + hard_swish=HardSwishJit, + hard_mish=HardMishJit +) + +_ACT_LAYER_ME = dict( + swish=SwishMe, + mish=MishMe, + hard_sigmoid=HardSigmoidMe, + hard_swish=HardSwishMe, + hard_mish=HardMishMe, +) + + +def get_act_fn(name='relu'): + """ Activation Function Factory + Fetching activation fns by name with this function allows export or torch script friendly + functions to be returned dynamically based on current config. + """ + if not name: + return None + if not (is_no_jit() or is_exportable() or is_scriptable()): + # If not exporting or scripting the model, first look for a memory-efficient version with + # custom autograd, then fallback + if name in _ACT_FN_ME: + return _ACT_FN_ME[name] + if not is_no_jit(): + if name in _ACT_FN_JIT: + return _ACT_FN_JIT[name] + return _ACT_FN_DEFAULT[name] + + +def get_act_layer(name='relu'): + """ Activation Layer Factory + Fetching activation layers by name with this function allows export or torch script friendly + functions to be returned dynamically based on current config. 
+ """ + if not name: + return None + if not (is_no_jit() or is_exportable() or is_scriptable()): + if name in _ACT_LAYER_ME: + return _ACT_LAYER_ME[name] + if not is_no_jit(): + if name in _ACT_LAYER_JIT: + return _ACT_LAYER_JIT[name] + return _ACT_LAYER_DEFAULT[name] + + +def create_act_layer(name, inplace=False, **kwargs): + act_layer = get_act_layer(name) + if act_layer is not None: + return act_layer(inplace=inplace, **kwargs) + else: + return None diff --git a/timm/models/layers/create_attn.py b/timm/models/layers/create_attn.py index 94c4e4e7..24eccaa0 100644 --- a/timm/models/layers/create_attn.py +++ b/timm/models/layers/create_attn.py @@ -3,7 +3,7 @@ Hacked together by Ross Wightman """ import torch -from .se import SEModule +from .se import SEModule, EffectiveSEModule from .eca import EcaModule, CecaModule from .cbam import CbamModule, LightCbamModule @@ -15,6 +15,8 @@ def create_attn(attn_type, channels, **kwargs): attn_type = attn_type.lower() if attn_type == 'se': module_cls = SEModule + elif attn_type == 'ese': + module_cls = EffectiveSEModule elif attn_type == 'eca': module_cls = EcaModule elif attn_type == 'ceca': diff --git a/timm/models/layers/create_conv2d.py b/timm/models/layers/create_conv2d.py index 527c80a3..34fbd44f 100644 --- a/timm/models/layers/create_conv2d.py +++ b/timm/models/layers/create_conv2d.py @@ -8,23 +8,23 @@ from .cond_conv2d import CondConv2d from .conv2d_same import create_conv2d_pad -def create_conv2d(in_chs, out_chs, kernel_size, **kwargs): +def create_conv2d(in_channels, out_channels, kernel_size, **kwargs): """ Select a 2d convolution implementation based on arguments Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d. Used extensively by EfficientNet, MobileNetv3 and related networks. """ - assert 'groups' not in kwargs # only use 'depthwise' bool arg if isinstance(kernel_size, list): assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently + assert 'groups' not in kwargs # MixedConv groups are defined by kernel list # We're going to use only lists for defining the MixedConv2d kernel groups, # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. 
- m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs) + m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs) else: depthwise = kwargs.pop('depthwise', False) - groups = out_chs if depthwise else 1 + groups = out_channels if depthwise else kwargs.pop('groups', 1) if 'num_experts' in kwargs and kwargs['num_experts'] > 0: - m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs) + m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs) else: - m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs) + m = create_conv2d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs) return m diff --git a/timm/models/layers/create_norm_act.py b/timm/models/layers/create_norm_act.py index 251c0c17..7bdaa125 100644 --- a/timm/models/layers/create_norm_act.py +++ b/timm/models/layers/create_norm_act.py @@ -1,37 +1,64 @@ +import types +import functools + import torch import torch.nn as nn from .evo_norm import EvoNormBatch2d, EvoNormSample2d -from .norm_act import BatchNormAct2d -try: - from inplace_abn import InPlaceABN - has_iabn = True -except ImportError: - has_iabn = False +from .norm_act import BatchNormAct2d, GroupNormAct +from .inplace_abn import InplaceAbn +_NORM_ACT_TYPES = {BatchNormAct2d, GroupNormAct, EvoNormBatch2d, EvoNormSample2d, InplaceAbn} -def create_norm_act(layer_type, num_features, jit=False, **kwargs): - layer_parts = layer_type.split('_') - assert len(layer_parts) in (1, 2) - layer_class = layer_parts[0].lower() - #activation_class = layer_parts[1].lower() if len(layer_parts) > 1 else '' # FIXME support string act selection - - if layer_class == "batchnormact": - layer = BatchNormAct2d(num_features, **kwargs) # defaults to RELU of no kwargs override - elif layer_class == "batchnormrelu": - assert 'act_layer' not in kwargs - layer = BatchNormAct2d(num_features, act_layer=nn.ReLU, **kwargs) + +def get_norm_act_layer(layer_class): + layer_class = layer_class.replace('_', '').lower() + if layer_class.startswith("batchnorm"): + layer = BatchNormAct2d + elif layer_class.startswith("groupnorm"): + layer = GroupNormAct elif layer_class == "evonormbatch": - layer = EvoNormBatch2d(num_features, **kwargs) + layer = EvoNormBatch2d elif layer_class == "evonormsample": - layer = EvoNormSample2d(num_features, **kwargs) + layer = EvoNormSample2d elif layer_class == "iabn" or layer_class == "inplaceabn": - if not has_iabn: - raise ImportError( - "Pplease install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11'") - layer = InPlaceABN(num_features, **kwargs) + layer = InplaceAbn else: assert False, "Invalid norm_act layer (%s)" % layer_class - if jit: - layer = torch.jit.script(layer) return layer + + +def create_norm_act(layer_type, num_features, apply_act=True, jit=False, **kwargs): + layer_parts = layer_type.split('-') # e.g. batchnorm-leaky_relu + assert len(layer_parts) in (1, 2) + layer = get_norm_act_layer(layer_parts[0]) + #activation_class = layer_parts[1].lower() if len(layer_parts) > 1 else '' # FIXME support string act selection? 
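A short sketch of how the norm+act factory is meant to be used (illustrative, not part of this patch, assuming a checkout with it applied):

import torch
from torch import nn
from timm.models.layers import create_norm_act, get_norm_act_layer

bn_act = create_norm_act('batchnorm', 64, act_layer=nn.LeakyReLU)  # BatchNormAct2d w/ LeakyReLU applied after BN
evo = create_norm_act('evonormbatch', 64)                          # EvoNormBatch2d
print(get_norm_act_layer('inplaceabn'))                            # InplaceAbn wrapper class
y = bn_act(torch.randn(2, 64, 8, 8))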
+ layer_instance = layer(num_features, apply_act=apply_act, **kwargs) + if jit: + layer_instance = torch.jit.script(layer_instance) + return layer_instance + + +def convert_norm_act_type(norm_layer, act_layer, norm_kwargs=None): + assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) + assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial)) + norm_act_args = norm_kwargs.copy() if norm_kwargs else {} + if isinstance(norm_layer, str): + norm_act_layer = get_norm_act_layer(norm_layer) + elif norm_layer in _NORM_ACT_TYPES: + norm_act_layer = norm_layer + elif isinstance(norm_layer, (types.FunctionType, functools.partial)): + # assuming this is a lambda/fn/bound partial that creates norm_act layer + norm_act_layer = norm_layer + else: + type_name = norm_layer.__name__.lower() + if type_name.startswith('batchnorm'): + norm_act_layer = BatchNormAct2d + elif type_name.startswith('groupnorm'): + norm_act_layer = GroupNormAct + else: + assert False, f"No equivalent norm_act layer for {type_name}" + # Must pass `act_layer` through for backwards compat where `act_layer=None` implies no activation. + # Newer models will use `apply_act` and likely have `act_layer` arg bound to relevant NormAct types. + norm_act_args.update(dict(act_layer=act_layer)) + return norm_act_layer, norm_act_args diff --git a/timm/models/layers/drop.py b/timm/models/layers/drop.py index 5f2008c0..c91b969e 100644 --- a/timm/models/layers/drop.py +++ b/timm/models/layers/drop.py @@ -17,8 +17,6 @@ Hacked together by Ross Wightman import torch import torch.nn as nn import torch.nn.functional as F -import numpy as np -import math def drop_block_2d( diff --git a/timm/models/layers/evo_norm.py b/timm/models/layers/evo_norm.py index 62d49428..c7c00b80 100644 --- a/timm/models/layers/evo_norm.py +++ b/timm/models/layers/evo_norm.py @@ -2,9 +2,9 @@ An attempt at getting decent performing EvoNorms running in PyTorch. While currently faster than other impl, still quite a ways off the built-in BN -in terms of memory usage and throughput. +in terms of memory usage and throughput (roughly 5x mem, 1/2 - 1/3x speed). -Still very much a WIP, fiddling with buffer usage, in-place optimizations, and layouts. +Still very much a WIP, fiddling with buffer usage, in-place/jit optimizations, and layouts. 
Hacked together by Ross Wightman """ @@ -14,15 +14,15 @@ import torch.nn as nn class EvoNormBatch2d(nn.Module): - def __init__(self, num_features, momentum=0.1, nonlin=True, eps=1e-5): + def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, drop_block=None): super(EvoNormBatch2d, self).__init__() + self.apply_act = apply_act # apply activation (non-linearity) self.momentum = momentum - self.nonlin = nonlin self.eps = eps param_shape = (1, num_features, 1, 1) self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True) self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True) - if nonlin: + if apply_act: self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True) self.register_buffer('running_var', torch.ones(1, num_features, 1, 1)) self.reset_parameters() @@ -30,7 +30,7 @@ class EvoNormBatch2d(nn.Module): def reset_parameters(self): nn.init.ones_(self.weight) nn.init.zeros_(self.bias) - if self.nonlin: + if self.apply_act: nn.init.ones_(self.v) def forward(self, x): @@ -40,46 +40,42 @@ class EvoNormBatch2d(nn.Module): var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True) self.running_var.copy_(self.momentum * var.detach() + (1 - self.momentum) * self.running_var) else: - var = self.running_var.clone() + var = self.running_var - if self.nonlin: + if self.apply_act: v = self.v.to(dtype=x_type) - d = (x * v) + x.var(dim=(2, 3), unbiased=False, keepdim=True).add_(self.eps).sqrt_().to(dtype=x_type) - d = d.max(var.add_(self.eps).sqrt_().to(dtype=x_type)) + d = (x * v) + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type) + d = d.max((var + self.eps).sqrt().to(dtype=x_type)) x = x / d - return x.mul_(self.weight).add_(self.bias) - else: - return x.mul(self.weight).add_(self.bias) + return x * self.weight + self.bias class EvoNormSample2d(nn.Module): - def __init__(self, num_features, nonlin=True, groups=8, eps=1e-5): + def __init__(self, num_features, apply_act=True, groups=8, eps=1e-5, drop_block=None): super(EvoNormSample2d, self).__init__() - self.nonlin = nonlin + self.apply_act = apply_act # apply activation (non-linearity) self.groups = groups self.eps = eps param_shape = (1, num_features, 1, 1) self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True) self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True) - if nonlin: + if apply_act: self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True) self.reset_parameters() def reset_parameters(self): nn.init.ones_(self.weight) nn.init.zeros_(self.bias) - if self.nonlin: + if self.apply_act: nn.init.ones_(self.v) def forward(self, x): assert x.dim() == 4, 'expected 4D input' B, C, H, W = x.shape assert C % self.groups == 0 - if self.nonlin: + if self.apply_act: n = (x * self.v).sigmoid().reshape(B, self.groups, -1) x = x.reshape(B, self.groups, -1) - x = n / x.var(dim=-1, unbiased=False, keepdim=True).add_(self.eps).sqrt_() + x = n / (x.var(dim=-1, unbiased=False, keepdim=True) + self.eps).sqrt() x = x.reshape(B, C, H, W) - return x.mul_(self.weight).add_(self.bias) - else: - return x.mul(self.weight).add_(self.bias) + return x * self.weight + self.bias diff --git a/timm/models/layers/inplace_abn.py b/timm/models/layers/inplace_abn.py new file mode 100644 index 00000000..d78079db --- /dev/null +++ b/timm/models/layers/inplace_abn.py @@ -0,0 +1,85 @@ +import torch +from torch import nn as nn + +try: + from inplace_abn.functions import inplace_abn, inplace_abn_sync + has_iabn = True +except ImportError: + has_iabn = False + + 
def inplace_abn(x, weight, bias, running_mean, running_var, + training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01): + raise ImportError( + "Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11'") + + def inplace_abn_sync(**kwargs): + inplace_abn(**kwargs) + + +class InplaceAbn(nn.Module): + """Activated Batch Normalization + + This gathers a BatchNorm and an activation function in a single module + + Parameters + ---------- + num_features : int + Number of feature channels in the input and output. + eps : float + Small constant to prevent numerical issues. + momentum : float + Momentum factor applied to compute running statistics. + affine : bool + If `True` apply learned scale and shift transformation after normalization. + act_layer : str or nn.Module type + Name or type of the activation functions, one of: `leaky_relu`, `elu` + act_param : float + Negative slope for the `leaky_relu` activation. + """ + + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True, + act_layer="leaky_relu", act_param=0.01, drop_block=None,): + super(InplaceAbn, self).__init__() + self.num_features = num_features + self.affine = affine + self.eps = eps + self.momentum = momentum + if apply_act: + if isinstance(act_layer, str): + assert act_layer in ('leaky_relu', 'elu', 'identity') + self.act_name = act_layer + else: + # convert act layer passed as type to string + if isinstance(act_layer, nn.ELU): + self.act_name = 'elu' + elif isinstance(act_layer, nn.LeakyReLU): + self.act_name = 'leaky_relu' + else: + assert False, f'Invalid act layer {act_layer.__name__} for IABN' + else: + self.act_name = 'identity' + self.act_param = act_param + if self.affine: + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + else: + self.register_parameter('weight', None) + self.register_parameter('bias', None) + self.register_buffer('running_mean', torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.constant_(self.running_mean, 0) + nn.init.constant_(self.running_var, 1) + if self.affine: + nn.init.constant_(self.weight, 1) + nn.init.constant_(self.bias, 0) + + def forward(self, x): + output = inplace_abn( + x, self.weight, self.bias, self.running_mean, self.running_var, + self.training, self.momentum, self.eps, self.act_name, self.act_param) + if isinstance(output, tuple): + output = output[0] + return output diff --git a/timm/models/layers/norm_act.py b/timm/models/layers/norm_act.py index 879a8939..48c4d6da 100644 --- a/timm/models/layers/norm_act.py +++ b/timm/models/layers/norm_act.py @@ -1,28 +1,33 @@ """ Normalization + Activation Layers """ +import torch from torch import nn as nn from torch.nn import functional as F +from .create_act import get_act_layer + class BatchNormAct2d(nn.BatchNorm2d): """BatchNorm + Activation - This module performs BatchNorm + Actibation in s manner that will remain bavkwards + This module performs BatchNorm + Activation in a manner that will remain backwards compatible with weights trained with separate bn, act. This is why we inherit from BN instead of composing it as a .bn member. 
""" - def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, - track_running_stats=True, act_layer=nn.ReLU, inplace=True): - super(BatchNormAct2d, self).__init__(num_features, eps, momentum, affine, track_running_stats) - self.act = act_layer(inplace=inplace) - - def forward(self, x): - # FIXME cannot call parent forward() and maintain jit.script compatibility? - # x = super(BatchNormAct2d, self).forward(x) - - # BEGIN nn.BatchNorm2d forward() cut & paste - # self._check_input_dim(x) + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, + apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None): + super(BatchNormAct2d, self).__init__( + num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats) + if isinstance(act_layer, str): + act_layer = get_act_layer(act_layer) + if act_layer is not None and apply_act: + self.act = act_layer(inplace=inplace) + else: + self.act = None + def _forward_jit(self, x): + """ A cut & paste of the contents of the PyTorch BatchNorm2d forward function + """ # exponential_average_factor is self.momentum set to # (when it is available) only so that if gets updated # in ONNX graph when this node is exported to ONNX. @@ -41,10 +46,40 @@ class BatchNormAct2d(nn.BatchNorm2d): exponential_average_factor = self.momentum x = F.batch_norm( - x, self.running_mean, self.running_var, self.weight, self.bias, - self.training or not self.track_running_stats, - exponential_average_factor, self.eps) - # END BatchNorm2d forward() + x, self.running_mean, self.running_var, self.weight, self.bias, + self.training or not self.track_running_stats, + exponential_average_factor, self.eps) + return x + + @torch.jit.ignore + def _forward_python(self, x): + return super(BatchNormAct2d, self).forward(x) + + def forward(self, x): + # FIXME cannot call parent forward() and maintain jit.script compatibility? 
+ if torch.jit.is_scripting(): + x = self._forward_jit(x) + else: + x = self._forward_python(x) + if self.act is not None: + x = self.act(x) + return x + - x = self.act(x) +class GroupNormAct(nn.GroupNorm): + + def __init__(self, num_groups, num_channels, eps=1e-5, affine=True, + apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None): + super(GroupNormAct, self).__init__(num_groups, num_channels, eps=eps, affine=affine) + if isinstance(act_layer, str): + act_layer = get_act_layer(act_layer) + if act_layer is not None and apply_act: + self.act = act_layer(inplace=inplace) + else: + self.act = None + + def forward(self, x): + x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + if self.act is not None: + x = self.act(x) return x diff --git a/timm/models/layers/pool2d_same.py b/timm/models/layers/pool2d_same.py index 40f6dacc..7135f831 100644 --- a/timm/models/layers/pool2d_same.py +++ b/timm/models/layers/pool2d_same.py @@ -6,7 +6,6 @@ import torch import torch.nn as nn import torch.nn.functional as F from typing import Union, List, Tuple, Optional -import math from .helpers import tup_pair from .padding import pad_same, get_padding_value diff --git a/timm/models/layers/se.py b/timm/models/layers/se.py index 6bb4723e..83389fc5 100644 --- a/timm/models/layers/se.py +++ b/timm/models/layers/se.py @@ -1,9 +1,11 @@ from torch import nn as nn +from .create_act import get_act_fn class SEModule(nn.Module): - def __init__(self, channels, reduction=16, act_layer=nn.ReLU, min_channels=8, reduction_channels=None): + def __init__(self, channels, reduction=16, act_layer=nn.ReLU, min_channels=8, reduction_channels=None, + gate_fn='hard_sigmoid'): super(SEModule, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) reduction_channels = reduction_channels or max(channels // reduction, min_channels) @@ -12,10 +14,27 @@ class SEModule(nn.Module): self.act = act_layer(inplace=True) self.fc2 = nn.Conv2d( reduction_channels, channels, kernel_size=1, padding=0, bias=True) + self.gate_fn = get_act_fn(gate_fn) def forward(self, x): x_se = self.avg_pool(x) x_se = self.fc1(x_se) x_se = self.act(x_se) x_se = self.fc2(x_se) - return x * x_se.sigmoid() + return x * self.gate_fn(x_se) + + +class EffectiveSEModule(nn.Module): + """ Effective Squeeze-Excitation + From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 + """ + def __init__(self, channel, gate_fn='hard_sigmoid'): + super(EffectiveSEModule, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Conv2d(channel, channel, kernel_size=1, padding=0) + self.gate_fn = get_act_fn(gate_fn) + + def forward(self, x): + x_se = self.avg_pool(x) + x_se = self.fc(x_se) + return x * self.gate_fn(x_se, inplace=True) diff --git a/timm/models/layers/selective_kernel.py b/timm/models/layers/selective_kernel.py index ed9132de..e7535f71 100644 --- a/timm/models/layers/selective_kernel.py +++ b/timm/models/layers/selective_kernel.py @@ -4,7 +4,6 @@ Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586) Hacked together by Ross Wightman """ - import torch from torch import nn as nn diff --git a/timm/models/layers/separable_conv.py b/timm/models/layers/separable_conv.py new file mode 100644 index 00000000..3df0387a --- /dev/null +++ b/timm/models/layers/separable_conv.py @@ -0,0 +1,51 @@ +from torch import nn as nn + +from .create_conv2d import create_conv2d +from .create_norm_act import convert_norm_act_type + + +class SeparableConvBnAct(nn.Module): + """ Separable Conv w/ 
trailing Norm and Activation + """ + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, + channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, norm_kwargs=None, + act_layer=nn.ReLU, apply_act=True, drop_block=None): + super(SeparableConvBnAct, self).__init__() + norm_kwargs = norm_kwargs or {} + + self.conv_dw = create_conv2d( + in_channels, int(in_channels * channel_multiplier), kernel_size, + stride=stride, dilation=dilation, padding=padding, depthwise=True) + + self.conv_pw = create_conv2d( + int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) + + norm_act_layer, norm_act_args = convert_norm_act_type(norm_layer, act_layer, norm_kwargs) + self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block, **norm_act_args) + + def forward(self, x): + x = self.conv_dw(x) + x = self.conv_pw(x) + if self.bn is not None: + x = self.bn(x) + return x + + +class SeparableConv2d(nn.Module): + """ Separable Conv + """ + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, + channel_multiplier=1.0, pw_kernel_size=1): + super(SeparableConv2d, self).__init__() + + self.conv_dw = create_conv2d( + in_channels, int(in_channels * channel_multiplier), kernel_size, + stride=stride, dilation=dilation, padding=padding, depthwise=True) + + self.conv_pw = create_conv2d( + int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) + + def forward(self, x): + x = self.conv_dw(x) + x = self.conv_pw(x) + return x diff --git a/timm/models/layers/test_time_pool.py b/timm/models/layers/test_time_pool.py index dcfc66ca..b2f3d2c3 100644 --- a/timm/models/layers/test_time_pool.py +++ b/timm/models/layers/test_time_pool.py @@ -6,6 +6,7 @@ Hacked together by Ross Wightman import logging from torch import nn import torch.nn.functional as F + from .adaptive_avgmax_pool import adaptive_avgmax_pool2d diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index 8daebdf0..e1a700b0 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -7,13 +7,15 @@ Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244 Hacked together by Ross Wightman """ +import torch.nn as nn +import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .efficientnet_builder import * +from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT +from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights from .feature_hooks import FeatureHooks from .helpers import load_pretrained -from .layers import SelectAdaptivePool2d, create_conv2d -from .layers.activations import HardSwish, hard_sigmoid +from .layers import SelectAdaptivePool2d, create_conv2d, get_act_fn, hard_sigmoid from .registry import register_model __all__ = ['MobileNetV3'] @@ -273,8 +275,8 @@ def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kw head_bias=False, channel_multiplier=channel_multiplier, norm_kwargs=resolve_bn_args(kwargs), - act_layer=HardSwish, - se_kwargs=dict(gate_fn=hard_sigmoid, reduce_mid=True, divisor=1), + act_layer=resolve_act_layer(kwargs, 'hard_swish'), + se_kwargs=dict(gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True, divisor=1), **kwargs, ) model = _create_model(model_kwargs, default_cfgs[variant], 
pretrained) @@ -293,7 +295,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg if 'small' in variant: num_features = 1024 if 'minimal' in variant: - act_layer = nn.ReLU + act_layer = resolve_act_layer(kwargs, 'relu') arch_def = [ # stage 0, 112x112 in ['ds_r1_k3_s2_e1_c16'], @@ -309,7 +311,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg ['cn_r1_k1_s1_c576'], ] else: - act_layer = HardSwish + act_layer = resolve_act_layer(kwargs, 'hard_swish') arch_def = [ # stage 0, 112x112 in ['ds_r1_k3_s2_e1_c16_se0.25_nre'], # relu @@ -327,7 +329,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg else: num_features = 1280 if 'minimal' in variant: - act_layer = nn.ReLU + act_layer = resolve_act_layer(kwargs, 'relu') arch_def = [ # stage 0, 112x112 in ['ds_r1_k3_s1_e1_c16'], @@ -345,7 +347,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg ['cn_r1_k1_s1_c960'], ] else: - act_layer = HardSwish + act_layer = resolve_act_layer(kwargs, 'hard_swish') arch_def = [ # stage 0, 112x112 in ['ds_r1_k3_s1_e1_c16_nre'], # relu diff --git a/timm/models/pnasnet.py b/timm/models/pnasnet.py index 97c2f86d..56614bd6 100644 --- a/timm/models/pnasnet.py +++ b/timm/models/pnasnet.py @@ -43,11 +43,12 @@ class MaxPool(nn.Module): self.pool = nn.MaxPool2d(kernel_size, stride=stride, padding=padding) def forward(self, x): - if self.zero_pad: + if self.zero_pad is not None: x = self.zero_pad(x) - x = self.pool(x) - if self.zero_pad: + x = self.pool(x) x = x[:, :, 1:, 1:] + else: + x = self.pool(x) return x @@ -90,11 +91,12 @@ class BranchSeparables(nn.Module): def forward(self, x): x = self.relu_1(x) - if self.zero_pad: + if self.zero_pad is not None: x = self.zero_pad(x) - x = self.separable_1(x) - if self.zero_pad: + x = self.separable_1(x) x = x[:, :, 1:, 1:].contiguous() + else: + x = self.separable_1(x) x = self.bn_sep_1(x) x = self.relu_2(x) x = self.separable_2(x) @@ -171,15 +173,14 @@ class CellBase(nn.Module): x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right x_comb_iter_4_left = self.comb_iter_4_left(x_left) - if self.comb_iter_4_right: + if self.comb_iter_4_right is not None: x_comb_iter_4_right = self.comb_iter_4_right(x_right) else: x_comb_iter_4_right = x_right x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right x_out = torch.cat( - [x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, - x_comb_iter_4], 1) + [x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) return x_out @@ -280,9 +281,8 @@ class Cell(CellBase): kernel_size=3, stride=stride, zero_pad=zero_pad) if is_reduction: - self.comb_iter_4_right = ReluConvBn(out_channels_right, - out_channels_right, - kernel_size=1, stride=stride) + self.comb_iter_4_right = ReluConvBn( + out_channels_right, out_channels_right, kernel_size=1, stride=stride) else: self.comb_iter_4_right = None diff --git a/timm/models/res2net.py b/timm/models/res2net.py index 3e3882fe..b095de30 100644 --- a/timm/models/res2net.py +++ b/timm/models/res2net.py @@ -77,6 +77,8 @@ class Bottle2neck(nn.Module): if self.is_first: # FIXME this should probably have count_include_pad=False, but hurts original weights self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1) + else: + self.pool = None self.conv3 = nn.Conv2d(width * scale, outplanes, kernel_size=1, bias=False) self.bn3 = norm_layer(outplanes) @@ -97,14 +99,22 @@ class Bottle2neck(nn.Module): spx = torch.split(out, self.width, 1) spo = [] + sp = 
spx[0] for i, (conv, bn) in enumerate(zip(self.convs, self.bns)): - sp = spx[i] if i == 0 or self.is_first else sp + spx[i] + if i == 0 or self.is_first: + sp = spx[i] + else: + sp = sp + spx[i] sp = conv(sp) sp = bn(sp) sp = self.relu(sp) spo.append(sp) if self.scale > 1: - spo.append(self.pool(spx[-1]) if self.is_first else spx[-1]) + if self.pool is not None: + # self.is_first == True, None check for torchscript + spo.append(self.pool(spx[-1])) + else: + spo.append(spx[-1]) out = torch.cat(spo, 1) out = self.conv3(out) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 430bbb49..8750c5bd 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -200,7 +200,6 @@ class BasicBlock(nn.Module): class Bottleneck(nn.Module): - __constants__ = ['se', 'downsample'] # for pre 1.4 torchscript compat expansion = 4 def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, diff --git a/timm/models/selecsls.py b/timm/models/selecsls.py index 77bdd2c9..b7573086 100644 --- a/timm/models/selecsls.py +++ b/timm/models/selecsls.py @@ -9,6 +9,7 @@ https://arxiv.org/abs/1907.00837 Based on ResNet implementation in https://github.com/rwightman/pytorch-image-models and SelecSLS Net implementation in https://github.com/mehtadushy/SelecSLS-Pytorch """ +from typing import List import torch import torch.nn as nn @@ -52,6 +53,27 @@ default_cfgs = { } + +class SequentialList(nn.Sequential): + + def __init__(self, *args): + super(SequentialList, self).__init__(*args) + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (List[torch.Tensor]) -> (List[torch.Tensor]) + pass + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (torch.Tensor) -> (List[torch.Tensor]) + pass + + def forward(self, x) -> List[torch.Tensor]: + for module in self: + x = module(x) + return x + + def conv_bn(in_chs, out_chs, k=3, stride=1, padding=None, dilation=1): if padding is None: padding = ((stride - 1) + dilation * (k - 1)) // 2 @@ -77,7 +99,7 @@ class SelecSLSBlock(nn.Module): self.conv5 = conv_bn(mid_chs, mid_chs // 2, 3) self.conv6 = conv_bn(2 * mid_chs + (0 if is_first else skip_chs), out_chs, 1) - def forward(self, x): + def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: assert isinstance(x, list) assert len(x) in [1, 2] @@ -113,7 +135,7 @@ class SelecSLS(nn.Module): super(SelecSLS, self).__init__() self.stem = conv_bn(in_chans, 32, stride=2) - self.features = nn.Sequential(*[cfg['block'](*block_args) for block_args in cfg['features']]) + self.features = SequentialList(*[cfg['block'](*block_args) for block_args in cfg['features']]) self.head = nn.Sequential(*[conv_bn(*conv_args) for conv_args in cfg['head']]) self.num_features = cfg['num_features'] diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py index 55a6e195..a4274b2f 100644 --- a/timm/models/tresnet.py +++ b/timm/models/tresnet.py @@ -13,15 +13,9 @@ import torch.nn as nn import torch.nn.functional as F from .helpers import load_pretrained -from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, SelectAdaptivePool2d +from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, SelectAdaptivePool2d, InplaceAbn from .registry import register_model -try: - from inplace_abn import InPlaceABN - has_iabn = True -except ImportError: - has_iabn = False - __all__ = ['tresnet_m', 'tresnet_l', 'tresnet_xl'] @@ -91,37 +85,37 @@ class FastSEModule(nn.Module): def IABN2Float(module: nn.Module) -> nn.Module: """If `module` is IABN don't use half precision.""" - 
if isinstance(module, InPlaceABN): + if isinstance(module, InplaceAbn): module.float() for child in module.children(): IABN2Float(child) return module -def conv2d_ABN(ni, nf, stride, activation="leaky_relu", kernel_size=3, activation_param=1e-2, groups=1): +def conv2d_iabn(ni, nf, stride, kernel_size=3, groups=1, act_layer="leaky_relu", act_param=1e-2): return nn.Sequential( nn.Conv2d( ni, nf, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, groups=groups, bias=False), - InPlaceABN(num_features=nf, activation=activation, activation_param=activation_param) + InplaceAbn(nf, act_layer=act_layer, act_param=act_param) ) class BasicBlock(nn.Module): expansion = 1 - def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True, anti_alias_layer=None): + def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True, aa_layer=None): super(BasicBlock, self).__init__() if stride == 1: - self.conv1 = conv2d_ABN(inplanes, planes, stride=1, activation_param=1e-3) + self.conv1 = conv2d_iabn(inplanes, planes, stride=1, act_param=1e-3) else: - if anti_alias_layer is None: - self.conv1 = conv2d_ABN(inplanes, planes, stride=2, activation_param=1e-3) + if aa_layer is None: + self.conv1 = conv2d_iabn(inplanes, planes, stride=2, act_param=1e-3) else: self.conv1 = nn.Sequential( - conv2d_ABN(inplanes, planes, stride=1, activation_param=1e-3), - anti_alias_layer(channels=planes, filt_size=3, stride=2)) + conv2d_iabn(inplanes, planes, stride=1, act_param=1e-3), + aa_layer(channels=planes, filt_size=3, stride=2)) - self.conv2 = conv2d_ABN(planes, planes, stride=1, activation="identity") + self.conv2 = conv2d_iabn(planes, planes, stride=1, act_layer="identity") self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride @@ -148,24 +142,25 @@ class BasicBlock(nn.Module): class Bottleneck(nn.Module): expansion = 4 - def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True, anti_alias_layer=None): + def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True, + act_layer="leaky_relu", aa_layer=None): super(Bottleneck, self).__init__() - self.conv1 = conv2d_ABN( - inplanes, planes, kernel_size=1, stride=1, activation="leaky_relu", activation_param=1e-3) + self.conv1 = conv2d_iabn( + inplanes, planes, kernel_size=1, stride=1, act_layer=act_layer, act_param=1e-3) if stride == 1: - self.conv2 = conv2d_ABN( - planes, planes, kernel_size=3, stride=1, activation="leaky_relu", activation_param=1e-3) + self.conv2 = conv2d_iabn( + planes, planes, kernel_size=3, stride=1, act_layer=act_layer, act_param=1e-3) else: - if anti_alias_layer is None: - self.conv2 = conv2d_ABN( - planes, planes, kernel_size=3, stride=2, activation="leaky_relu", activation_param=1e-3) + if aa_layer is None: + self.conv2 = conv2d_iabn( + planes, planes, kernel_size=3, stride=2, act_layer=act_layer, act_param=1e-3) else: self.conv2 = nn.Sequential( - conv2d_ABN(planes, planes, kernel_size=3, stride=1, activation="leaky_relu", activation_param=1e-3), - anti_alias_layer(channels=planes, filt_size=3, stride=2)) + conv2d_iabn(planes, planes, kernel_size=3, stride=1, act_layer=act_layer, act_param=1e-3), + aa_layer(channels=planes, filt_size=3, stride=2)) - self.conv3 = conv2d_ABN( - planes, planes * self.expansion, kernel_size=1, stride=1, activation="identity") + self.conv3 = conv2d_iabn( + planes, planes * self.expansion, kernel_size=1, stride=1, act_layer="identity") self.relu = nn.ReLU(inplace=True) self.downsample = downsample @@ -195,30 +190,26 @@ 
class Bottleneck(nn.Module): class TResNet(nn.Module): def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, no_aa_jit=False, global_pool='avg', drop_rate=0.): - if not has_iabn: - raise ImportError( - "For TResNet models, please install InplaceABN: " - "'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11'") self.num_classes = num_classes self.drop_rate = drop_rate super(TResNet, self).__init__() # JIT layers space_to_depth = SpaceToDepthModule() - anti_alias_layer = partial(AntiAliasDownsampleLayer, no_jit=no_aa_jit) + aa_layer = partial(AntiAliasDownsampleLayer, no_jit=no_aa_jit) # TResnet stages self.inplanes = int(64 * width_factor) self.planes = int(64 * width_factor) - conv1 = conv2d_ABN(in_chans * 16, self.planes, stride=1, kernel_size=3) + conv1 = conv2d_iabn(in_chans * 16, self.planes, stride=1, kernel_size=3) layer1 = self._make_layer( - BasicBlock, self.planes, layers[0], stride=1, use_se=True, anti_alias_layer=anti_alias_layer) # 56x56 + BasicBlock, self.planes, layers[0], stride=1, use_se=True, aa_layer=aa_layer) # 56x56 layer2 = self._make_layer( - BasicBlock, self.planes * 2, layers[1], stride=2, use_se=True, anti_alias_layer=anti_alias_layer) # 28x28 + BasicBlock, self.planes * 2, layers[1], stride=2, use_se=True, aa_layer=aa_layer) # 28x28 layer3 = self._make_layer( - Bottleneck, self.planes * 4, layers[2], stride=2, use_se=True, anti_alias_layer=anti_alias_layer) # 14x14 + Bottleneck, self.planes * 4, layers[2], stride=2, use_se=True, aa_layer=aa_layer) # 14x14 layer4 = self._make_layer( - Bottleneck, self.planes * 8, layers[3], stride=2, use_se=False, anti_alias_layer=anti_alias_layer) # 7x7 + Bottleneck, self.planes * 8, layers[3], stride=2, use_se=False, aa_layer=aa_layer) # 7x7 # body self.body = nn.Sequential(OrderedDict([ @@ -239,7 +230,7 @@ class TResNet(nn.Module): for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') - elif isinstance(m, nn.BatchNorm2d) or isinstance(m, InPlaceABN): + elif isinstance(m, nn.BatchNorm2d) or isinstance(m, InplaceAbn): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) @@ -251,24 +242,24 @@ class TResNet(nn.Module): m.conv3[1].weight = nn.Parameter(torch.zeros_like(m.conv3[1].weight)) # BN to zero if isinstance(m, nn.Linear): m.weight.data.normal_(0, 0.01) - def _make_layer(self, block, planes, blocks, stride=1, use_se=True, anti_alias_layer=None): + def _make_layer(self, block, planes, blocks, stride=1, use_se=True, aa_layer=None): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: layers = [] if stride == 2: # avg pooling before 1x1 conv layers.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True, count_include_pad=False)) - layers += [conv2d_ABN( - self.inplanes, planes * block.expansion, kernel_size=1, stride=1, activation="identity")] + layers += [conv2d_iabn( + self.inplanes, planes * block.expansion, kernel_size=1, stride=1, act_layer="identity")] downsample = nn.Sequential(*layers) layers = [] layers.append(block( - self.inplanes, planes, stride, downsample, use_se=use_se, anti_alias_layer=anti_alias_layer)) + self.inplanes, planes, stride, downsample, use_se=use_se, aa_layer=aa_layer)) self.inplanes = planes * block.expansion for i in range(1, blocks): layers.append( - block(self.inplanes, planes, use_se=use_se, anti_alias_layer=anti_alias_layer)) + block(self.inplanes, planes, use_se=use_se, aa_layer=aa_layer)) return nn.Sequential(*layers) def get_classifier(self): 
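A minimal usage sketch of the InplaceAbn wrapper that the TResNet code above now relies on (not part of the patch; it assumes timm with these changes is importable, so that InplaceAbn is exposed via timm.models.layers as in the tresnet import). Constructing the module does not require the optional inplace_abn package; only the forward pass reaches the fallback stub, which raises ImportError with the install command shown earlier.

    import torch
    import torch.nn as nn
    from timm.models.layers import InplaceAbn

    # mirrors conv2d_iabn: a Conv2d followed by the fused norm + activation layer
    block = nn.Sequential(
        nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False),
        InplaceAbn(64, act_layer='leaky_relu', act_param=1e-2),
    )
    out = block(torch.randn(2, 32, 56, 56))  # runs only if the mapillary inplace_abn package is installed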
diff --git a/timm/models/vovnet.py b/timm/models/vovnet.py new file mode 100644 index 00000000..bedff10c --- /dev/null +++ b/timm/models/vovnet.py @@ -0,0 +1,408 @@ +""" VoVNet (V1 & V2) + +Papers: +* `An Energy and GPU-Computation Efficient Backbone Network` - https://arxiv.org/abs/1904.09730 +* `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 + +Looked at https://github.com/youngwanLEE/vovnet-detectron2 & +https://github.com/stigma0617/VoVNet.pytorch/blob/master/models_vovnet/vovnet.py +for some reference, rewrote most of the code. + +Hacked together by Ross Wightman +""" + +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .registry import register_model +from .helpers import load_pretrained +from .layers import ConvBnAct, SeparableConvBnAct, BatchNormAct2d, SelectAdaptivePool2d, \ + create_attn, create_norm_act, get_norm_act_layer + + +# model cfgs adapted from https://github.com/youngwanLEE/vovnet-detectron2 & +# https://github.com/stigma0617/VoVNet.pytorch/blob/master/models_vovnet/vovnet.py +model_cfgs = dict( + vovnet39a=dict( + stem_ch=[64, 64, 128], + stage_conv_ch=[128, 160, 192, 224], + stage_out_ch=[256, 512, 768, 1024], + layer_per_block=5, + block_per_stage=[1, 1, 2, 2], + residual=False, + depthwise=False, + attn='', + ), + vovnet57a=dict( + stem_ch=[64, 64, 128], + stage_conv_ch=[128, 160, 192, 224], + stage_out_ch=[256, 512, 768, 1024], + layer_per_block=5, + block_per_stage=[1, 1, 4, 3], + residual=False, + depthwise=False, + attn='', + + ), + ese_vovnet19b_slim_dw=dict( + stem_ch=[64, 64, 64], + stage_conv_ch=[64, 80, 96, 112], + stage_out_ch=[112, 256, 384, 512], + layer_per_block=3, + block_per_stage=[1, 1, 1, 1], + residual=True, + depthwise=True, + attn='ese', + + ), + ese_vovnet19b_dw=dict( + stem_ch=[64, 64, 64], + stage_conv_ch=[128, 160, 192, 224], + stage_out_ch=[256, 512, 768, 1024], + layer_per_block=3, + block_per_stage=[1, 1, 1, 1], + residual=True, + depthwise=True, + attn='ese', + ), + ese_vovnet19b_slim=dict( + stem_ch=[64, 64, 128], + stage_conv_ch=[64, 80, 96, 112], + stage_out_ch=[112, 256, 384, 512], + layer_per_block=3, + block_per_stage=[1, 1, 1, 1], + residual=True, + depthwise=False, + attn='ese', + ), + ese_vovnet19b=dict( + stem_ch=[64, 64, 128], + stage_conv_ch=[128, 160, 192, 224], + stage_out_ch=[256, 512, 768, 1024], + layer_per_block=3, + block_per_stage=[1, 1, 1, 1], + residual=True, + depthwise=False, + attn='ese', + + ), + ese_vovnet39b=dict( + stem_ch=[64, 64, 128], + stage_conv_ch=[128, 160, 192, 224], + stage_out_ch=[256, 512, 768, 1024], + layer_per_block=5, + block_per_stage=[1, 1, 2, 2], + residual=True, + depthwise=False, + attn='ese', + ), + ese_vovnet57b=dict( + stem_ch=[64, 64, 128], + stage_conv_ch=[128, 160, 192, 224], + stage_out_ch=[256, 512, 768, 1024], + layer_per_block=5, + block_per_stage=[1, 1, 4, 3], + residual=True, + depthwise=False, + attn='ese', + + ), + ese_vovnet99b=dict( + stem_ch=[64, 64, 128], + stage_conv_ch=[128, 160, 192, 224], + stage_out_ch=[256, 512, 768, 1024], + layer_per_block=5, + block_per_stage=[1, 3, 9, 3], + residual=True, + depthwise=False, + attn='ese', + ), + eca_vovnet39b=dict( + stem_ch=[64, 64, 128], + stage_conv_ch=[128, 160, 192, 224], + stage_out_ch=[256, 512, 768, 1024], + layer_per_block=5, + block_per_stage=[1, 1, 2, 2], + residual=True, + depthwise=False, + attn='eca', + ), +) + + +def _cfg(url=''): + return { + 'url': 
url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.0.conv', 'classifier': 'head.fc', + } + + +default_cfgs = dict( + vovnet39a=_cfg(url=''), + vovnet57a=_cfg(url=''), + ese_vovnet19b_slim_dw=_cfg(url=''), + ese_vovnet19b_dw=_cfg(url=''), + ese_vovnet19b_slim=_cfg(url=''), + ese_vovnet39b=_cfg(url=''), + ese_vovnet57b=_cfg(url=''), + ese_vovnet99b=_cfg(url=''), + eca_vovnet39b=_cfg(url=''), +) + + +class SequentialAppendList(nn.Sequential): + def __init__(self, *args): + super(SequentialAppendList, self).__init__(*args) + + def forward(self, x: torch.Tensor, concat_list: List[torch.Tensor]) -> torch.Tensor: + for i, module in enumerate(self): + if i == 0: + concat_list.append(module(x)) + else: + concat_list.append(module(concat_list[-1])) + x = torch.cat(concat_list, dim=1) + return x + + +class OsaBlock(nn.Module): + + def __init__(self, in_chs, mid_chs, out_chs, layer_per_block, residual=False, + depthwise=False, attn='', norm_layer=BatchNormAct2d): + super(OsaBlock, self).__init__() + + self.residual = residual + self.depthwise = depthwise + + next_in_chs = in_chs + if self.depthwise and next_in_chs != mid_chs: + assert not residual + self.conv_reduction = ConvBnAct(next_in_chs, mid_chs, 1, norm_layer=norm_layer) + else: + self.conv_reduction = None + + mid_convs = [] + for i in range(layer_per_block): + if self.depthwise: + conv = SeparableConvBnAct(mid_chs, mid_chs, norm_layer=norm_layer) + else: + conv = ConvBnAct(next_in_chs, mid_chs, 3, norm_layer=norm_layer) + next_in_chs = mid_chs + mid_convs.append(conv) + self.conv_mid = SequentialAppendList(*mid_convs) + + # feature aggregation + next_in_chs = in_chs + layer_per_block * mid_chs + self.conv_concat = ConvBnAct(next_in_chs, out_chs, norm_layer=norm_layer) + + if attn: + self.attn = create_attn(attn, out_chs) + else: + self.attn = None + + def forward(self, x): + output = [x] + if self.conv_reduction is not None: + x = self.conv_reduction(x) + x = self.conv_mid(x, output) + x = self.conv_concat(x) + if self.attn is not None: + x = self.attn(x) + if self.residual: + x = x + output[0] + return x + + +class OsaStage(nn.Module): + + def __init__(self, in_chs, mid_chs, out_chs, block_per_stage, layer_per_block, + downsample=True, residual=True, depthwise=False, attn='ese', norm_layer=BatchNormAct2d): + super(OsaStage, self).__init__() + + if downsample: + self.pool = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True) + else: + self.pool = None + + blocks = [] + for i in range(block_per_stage): + last_block = i == block_per_stage - 1 + blocks += [OsaBlock( + in_chs if i == 0 else out_chs, mid_chs, out_chs, layer_per_block, residual=residual and i > 0, + depthwise=depthwise, attn=attn if last_block else '', norm_layer=norm_layer) + ] + self.blocks = nn.Sequential(*blocks) + + def forward(self, x): + if self.pool is not None: + x = self.pool(x) + x = self.blocks(x) + return x + + +class ClassifierHead(nn.Module): + """Head.""" + + def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0.): + super(ClassifierHead, self).__init__() + self.drop_rate = drop_rate + self.global_pool = SelectAdaptivePool2d(pool_type=pool_type) + if num_classes > 0: + self.fc = nn.Linear(in_chs, num_classes, bias=True) + else: + self.fc = nn.Identity() + + def forward(self, x): + x = self.global_pool(x).flatten(1) + if self.drop_rate: + x = F.dropout(x, p=float(self.drop_rate), 
training=self.training) + x = self.fc(x) + return x + + +class VovNet(nn.Module): + + def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0., stem_stride=4, + norm_layer=BatchNormAct2d): + """ VovNet (v2) + """ + super(VovNet, self).__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + assert stem_stride in (4, 2) + + stem_ch = cfg["stem_ch"] + stage_conv_ch = cfg["stage_conv_ch"] + stage_out_ch = cfg["stage_out_ch"] + block_per_stage = cfg["block_per_stage"] + layer_per_block = cfg["layer_per_block"] + + # Stem module + last_stem_stride = stem_stride // 2 + conv_type = SeparableConvBnAct if cfg["depthwise"] else ConvBnAct + self.stem = nn.Sequential(*[ + ConvBnAct(in_chans, stem_ch[0], 3, stride=2, norm_layer=norm_layer), + conv_type(stem_ch[0], stem_ch[1], 3, stride=1, norm_layer=norm_layer), + conv_type(stem_ch[1], stem_ch[2], 3, stride=last_stem_stride, norm_layer=norm_layer), + ]) + + # OSA stages + in_ch_list = stem_ch[-1:] + stage_out_ch[:-1] + stage_args = dict( + residual=cfg["residual"], depthwise=cfg["depthwise"], attn=cfg["attn"], norm_layer=norm_layer) + stages = [] + for i in range(4): # num_stages + downsample = stem_stride == 2 or i > 0 # first stage has no stride/downsample if stem_stride is 4 + stages += [OsaStage( + in_ch_list[i], stage_conv_ch[i], stage_out_ch[i], block_per_stage[i], layer_per_block, + downsample=downsample, **stage_args) + ] + self.num_features = stage_out_ch[i] + self.stages = nn.Sequential(*stages) + + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) + + for n, m in self.named_modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1.) + nn.init.constant_(m.bias, 0.) 
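+ # note: BatchNormAct2d subclasses nn.BatchNorm2d, so the default norm_layer is also covered by the isinstance branch above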
+ + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + + def forward_features(self, x): + x = self.stem(x) + return self.stages(x) + + def forward(self, x): + x = self.forward_features(x) + return self.head(x) + + +def _vovnet(variant, pretrained=False, **kwargs): + load_strict = True + model_class = VovNet + if kwargs.pop('features_only', False): + assert False, 'Not Implemented' # TODO + load_strict = False + kwargs.pop('num_classes', 0) + model_cfg = model_cfgs[variant] + default_cfg = default_cfgs[variant] + model = model_class(model_cfg, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained( + model, default_cfg, + num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=load_strict) + return model + + + +@register_model +def vovnet39a(pretrained=False, **kwargs): + return _vovnet('vovnet39a', pretrained=pretrained, **kwargs) + + +@register_model +def vovnet57a(pretrained=False, **kwargs): + return _vovnet('vovnet57a', pretrained=pretrained, **kwargs) + + +@register_model +def ese_vovnet19b_slim_dw(pretrained=False, **kwargs): + return _vovnet('ese_vovnet19b_slim_dw', pretrained=pretrained, **kwargs) + + +@register_model +def ese_vovnet19b_dw(pretrained=False, **kwargs): + return _vovnet('ese_vovnet19b_dw', pretrained=pretrained, **kwargs) + + +@register_model +def ese_vovnet19b_slim(pretrained=False, **kwargs): + return _vovnet('ese_vovnet19b_slim', pretrained=pretrained, **kwargs) + + +@register_model +def ese_vovnet39b(pretrained=False, **kwargs): + return _vovnet('ese_vovnet39b', pretrained=pretrained, **kwargs) + + +@register_model +def ese_vovnet57b(pretrained=False, **kwargs): + return _vovnet('ese_vovnet57b', pretrained=pretrained, **kwargs) + + +@register_model +def ese_vovnet99b(pretrained=False, **kwargs): + return _vovnet('ese_vovnet99b', pretrained=pretrained, **kwargs) + + +@register_model +def eca_vovnet39b(pretrained=False, **kwargs): + return _vovnet('eca_vovnet39b', pretrained=pretrained, **kwargs) + + +# Experimental Models + +@register_model +def ese_vovnet39b_iabn(pretrained=False, **kwargs): + norm_layer = get_norm_act_layer('iabn') + return _vovnet('ese_vovnet39b', pretrained=pretrained, norm_layer=norm_layer, **kwargs) + + +@register_model +def ese_vovnet39b_evos(pretrained=False, **kwargs): + def norm_act_fn(num_features, **kwargs): + return create_norm_act('EvoNormSample', num_features, jit=False, **kwargs) + return _vovnet('ese_vovnet39b', pretrained=pretrained, norm_layer=norm_act_fn, **kwargs) diff --git a/validate.py b/validate.py index f8ac7c55..ca031263 100755 --- a/validate.py +++ b/validate.py @@ -24,7 +24,8 @@ try: except ImportError: has_apex = False -from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models +from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models,\ + set_scriptable, set_no_jit from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging @@ -84,6 +85,9 @@ def validate(args): args.pretrained = args.pretrained or not args.checkpoint args.prefetcher = not args.no_prefetcher + if args.torchscript: + set_scriptable(True) + # create model model = create_model( args.model, @@ -141,8 +145,10 @@ def validate(args): top5 = AverageMeter() 
model.eval() - end = time.time() with torch.no_grad(): + # warm up to reduce variability of the first batch time, especially when comparing torchscript vs eager + model(torch.randn((args.batch_size,) + data_config['input_size']).cuda()) + end = time.time() for i, (input, target) in enumerate(loader): if args.no_prefetcher: target = target.cuda()