diff --git a/timm/models/densenet.py b/timm/models/densenet.py
index 420680f9..b9f9853c 100644
--- a/timm/models/densenet.py
+++ b/timm/models/densenet.py
@@ -4,6 +4,7 @@ fixed kwargs passthrough and addition of dynamic global avg/max pool.
 """
 import re
 from collections import OrderedDict
+from functools import partial
 
 import torch
 import torch.nn as nn
@@ -13,7 +14,7 @@ from torch.jit.annotations import List
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import load_pretrained
-from .layers import SelectAdaptivePool2d, BatchNormAct2d, EvoNormBatch2d, EvoNormSample2d
+from .layers import SelectAdaptivePool2d, BatchNormAct2d, create_norm_act
 from .registry import register_model
 
 __all__ = ['DenseNet']
@@ -327,9 +328,11 @@ def densenet121d_evob(pretrained=False, **kwargs):
     r"""Densenet-121 model from
     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
     """
+    def norm_act_fn(num_features, **kwargs):
+        return create_norm_act('EvoNormBatch', num_features, jit=True, **kwargs)
     model = _densenet(
         'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
-        norm_act_layer=EvoNormBatch2d, pretrained=pretrained, **kwargs)
+        norm_act_layer=norm_act_fn, pretrained=pretrained, **kwargs)
     return model
 
 
@@ -338,9 +341,11 @@ def densenet121d_evos(pretrained=False, **kwargs):
     r"""Densenet-121 model from
     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
     """
+    def norm_act_fn(num_features, **kwargs):
+        return create_norm_act('EvoNormSample', num_features, jit=True, **kwargs)
     model = _densenet(
         'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
-        norm_act_layer=EvoNormSample2d, pretrained=pretrained, **kwargs)
+        norm_act_layer=norm_act_fn, pretrained=pretrained, **kwargs)
     return model
 
 
@@ -349,10 +354,11 @@ def densenet121d_iabn(pretrained=False, **kwargs):
     r"""Densenet-121 model from
     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
     """
-    from inplace_abn import InPlaceABN
+    def norm_act_fn(num_features, **kwargs):
+        return create_norm_act('iabn', num_features, **kwargs)
     model = _densenet(
         'densenet121tn', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
-        norm_act_layer=InPlaceABN, pretrained=pretrained, **kwargs)
+        norm_act_layer=norm_act_fn, pretrained=pretrained, **kwargs)
     return model
 
 
diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py
index 12e7326e..94c98fdc 100644
--- a/timm/models/layers/__init__.py
+++ b/timm/models/layers/__init__.py
@@ -20,4 +20,5 @@ from .anti_aliasing import AntiAliasDownsampleLayer
 from .space_to_depth import SpaceToDepthModule
 from .blur_pool import BlurPool2d
 from .norm_act import BatchNormAct2d
-from .evo_norm import EvoNormBatch2d, EvoNormSample2d
\ No newline at end of file
+from .evo_norm import EvoNormBatch2d, EvoNormSample2d
+from .create_norm_act import create_norm_act
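Note on the densenet.py changes: each variant now routes its norm+act choice through a small norm_act_fn closure, so create_norm_act receives the layer-type string and jit flag while _densenet still gets a callable taking only num_features. The newly imported functools.partial can express the same binding; a minimal sketch, equivalent to the densenet121d_evob closure and assuming this patch is applied:

    from functools import partial
    from timm.models.layers import create_norm_act

    # Bind layer type and jit flag up front; num_features is supplied later
    # by the model builder, exactly like the norm_act_fn closures above.
    norm_act_fn = partial(create_norm_act, 'EvoNormBatch', jit=True)
    layer = norm_act_fn(32)  # scripted EvoNormBatch2d over 32 channels
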
diff --git a/timm/models/layers/create_norm_act.py b/timm/models/layers/create_norm_act.py
new file mode 100644
index 00000000..251c0c17
--- /dev/null
+++ b/timm/models/layers/create_norm_act.py
@@ -0,0 +1,37 @@
+import torch
+import torch.nn as nn
+
+from .evo_norm import EvoNormBatch2d, EvoNormSample2d
+from .norm_act import BatchNormAct2d
+try:
+    from inplace_abn import InPlaceABN
+    has_iabn = True
+except ImportError:
+    has_iabn = False
+
+
+def create_norm_act(layer_type, num_features, jit=False, **kwargs):
+    layer_parts = layer_type.split('_')
+    assert len(layer_parts) in (1, 2)
+    layer_class = layer_parts[0].lower()
+    #activation_class = layer_parts[1].lower() if len(layer_parts) > 1 else ''  # FIXME support string act selection
+
+    if layer_class == "batchnormact":
+        layer = BatchNormAct2d(num_features, **kwargs)  # defaults to ReLU if no kwargs override
+    elif layer_class == "batchnormrelu":
+        assert 'act_layer' not in kwargs
+        layer = BatchNormAct2d(num_features, act_layer=nn.ReLU, **kwargs)
+    elif layer_class == "evonormbatch":
+        layer = EvoNormBatch2d(num_features, **kwargs)
+    elif layer_class == "evonormsample":
+        layer = EvoNormSample2d(num_features, **kwargs)
+    elif layer_class == "iabn" or layer_class == "inplaceabn":
+        if not has_iabn:
+            raise ImportError(
+                "Please install InplaceABN: 'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11'")
+        layer = InPlaceABN(num_features, **kwargs)
+    else:
+        assert False, "Invalid norm_act layer (%s)" % layer_class
+    if jit:
+        layer = torch.jit.script(layer)
+    return layer
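create_norm_act maps a lowercased layer-type string to a concrete module and optionally wraps it in torch.jit.script. A quick usage sketch, assuming timm at this commit is importable:

    import torch
    from timm.models.layers import create_norm_act

    bn_relu = create_norm_act('batchnormrelu', 64)       # BatchNormAct2d with ReLU pinned
    evo = create_norm_act('EvoNormBatch', 64, jit=True)  # scripted EvoNormBatch2d

    x = torch.randn(2, 64, 8, 8)
    assert bn_relu(x).shape == evo(x).shape == x.shape   # both preserve shape
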
diff --git a/timm/models/layers/evo_norm.py b/timm/models/layers/evo_norm.py
index 79de23e9..62d49428 100644
--- a/timm/models/layers/evo_norm.py
+++ b/timm/models/layers/evo_norm.py
@@ -13,35 +13,12 @@ import torch
 import torch.nn as nn
 
 
-@torch.jit.script
-def evo_batch_jit(
-        x: torch.Tensor, v: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, running_var: torch.Tensor,
-        momentum: float, training: bool, nonlin: bool, eps: float):
-    x_type = x.dtype
-    running_var = running_var.detach()  # FIXME why is this needed, it's a buffer?
-    if training:
-        var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)  # FIXME biased, unbiased?
-        running_var.copy_(momentum * var + (1 - momentum) * running_var)
-    else:
-        var = running_var.clone()
-
-    if nonlin:
-        # FIXME biased, unbiased?
-        d = (x * v.to(x_type)) + x.var(dim=(2, 3), unbiased=False, keepdim=True).add_(eps).sqrt_().to(dtype=x_type)
-        d = d.max(var.add(eps).sqrt_().to(dtype=x_type))
-        x = x / d
-        return x.mul_(weight).add_(bias)
-    else:
-        return x.mul(weight).add_(bias)
-
-
 class EvoNormBatch2d(nn.Module):
-    def __init__(self, num_features, momentum=0.1, nonlin=True, eps=1e-5, jit=True):
+    def __init__(self, num_features, momentum=0.1, nonlin=True, eps=1e-5):
         super(EvoNormBatch2d, self).__init__()
         self.momentum = momentum
         self.nonlin = nonlin
         self.eps = eps
-        self.jit = jit
         param_shape = (1, num_features, 1, 1)
         self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True)
         self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True)
@@ -58,50 +35,29 @@ class EvoNormBatch2d(nn.Module):
     def forward(self, x):
         assert x.dim() == 4, 'expected 4D input'
-
-        if self.jit:
-            return evo_batch_jit(
-                x, self.v, self.weight, self.bias, self.running_var, self.momentum,
-                self.training, self.nonlin, self.eps)
+        x_type = x.dtype
+        if self.training:
+            var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
+            self.running_var.copy_(self.momentum * var.detach() + (1 - self.momentum) * self.running_var)
         else:
-            x_type = x.dtype
-            if self.training:
-                var = x.var(dim=(0, 2, 3), keepdim=True)
-                self.running_var.copy_(self.momentum * var + (1 - self.momentum) * self.running_var)
-            else:
-                var = self.running_var.clone()
-
-            if self.nonlin:
-                v = self.v.to(dtype=x_type)
-                d = (x * v) + x.var(dim=(2, 3), keepdim=True).add_(self.eps).sqrt_().to(dtype=x_type)
-                d = d.max(var.add(self.eps).sqrt_().to(dtype=x_type))
-                x = x / d
-                return x.mul_(self.weight).add_(self.bias)
-            else:
-                return x.mul(self.weight).add_(self.bias)
+            var = self.running_var.clone()
-
-@torch.jit.script
-def evo_sample_jit(
-        x: torch.Tensor, v: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
-        groups: int, nonlin: bool, eps: float):
-    B, C, H, W = x.shape
-    assert C % groups == 0
-    if nonlin:
-        n = (x * v).sigmoid_().reshape(B, groups, -1)
-        x = x.reshape(B, groups, -1)
-        x = n / x.var(dim=-1, unbiased=False, keepdim=True).add_(eps).sqrt_()
-        x = x.reshape(B, C, H, W)
-        return x.mul_(weight).add_(bias)
-    else:
-        return x.mul(weight).add_(bias)
+        if self.nonlin:
+            v = self.v.to(dtype=x_type)
+            d = (x * v) + x.var(dim=(2, 3), unbiased=False, keepdim=True).add_(self.eps).sqrt_().to(dtype=x_type)
+            d = d.max(var.add_(self.eps).sqrt_().to(dtype=x_type))
+            x = x / d
+            return x.mul_(self.weight).add_(self.bias)
+        else:
+            return x.mul(self.weight).add_(self.bias)
 
 
 class EvoNormSample2d(nn.Module):
-    def __init__(self, num_features, nonlin=True, groups=8, eps=1e-5, jit=True):
+    def __init__(self, num_features, nonlin=True, groups=8, eps=1e-5):
         super(EvoNormSample2d, self).__init__()
         self.nonlin = nonlin
         self.groups = groups
         self.eps = eps
-        self.jit = jit
         param_shape = (1, num_features, 1, 1)
         self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True)
         self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True)
@@ -117,18 +73,13 @@
 
     def forward(self, x):
         assert x.dim() == 4, 'expected 4D input'
-
-        if self.jit:
-            return evo_sample_jit(
-                x, self.v, self.weight, self.bias, self.groups, self.nonlin, self.eps)
+        B, C, H, W = x.shape
+        assert C % self.groups == 0
+        if self.nonlin:
+            n = (x * self.v).sigmoid().reshape(B, self.groups, -1)
+            x = x.reshape(B, self.groups, -1)
+            x = n / x.var(dim=-1, unbiased=False, keepdim=True).add_(self.eps).sqrt_()
+            x = x.reshape(B, C, H, W)
+            return x.mul_(self.weight).add_(self.bias)
         else:
-            B, C, H, W = x.shape
-            assert C % self.groups == 0
-            if self.nonlin:
-                n = (x * self.v).sigmoid().reshape(B, self.groups, -1)
-                x = x.reshape(B, self.groups, -1)
-                x = n / (x.std(dim=-1, unbiased=False, keepdim=True) + self.eps)
-                x = x.reshape(B, C, H, W)
-                return x.mul_(self.weight).add_(self.bias)
-            else:
-                return x.mul(self.weight).add_(self.bias)
+            return x.mul(self.weight).add_(self.bias)
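
Note on the evo_norm.py changes: the modules are now plain eager PyTorch, with torch.jit.script applied from the outside via create_norm_act(..., jit=True) rather than via internal jit branches. The eager paths were also aligned with the previous jitted math (unbiased=False variance, and sqrt(var + eps) instead of std + eps in the sample variant). A small behavioral sanity sketch, assuming this patch is applied; it checks that training updates running_var from the batch variance while eval works on a clone, so the in-place add_ of eps cannot corrupt the buffer:

    import torch
    from timm.models.layers import EvoNormBatch2d, EvoNormSample2d

    m = EvoNormBatch2d(16)
    x = torch.randn(4, 16, 8, 8)

    m.train()
    _ = m(x)                       # running_var updated from batch variance
    snapshot = m.running_var.clone()
    m.eval()
    _ = m(x)                       # operates on a clone; buffer stays untouched
    assert torch.equal(m.running_var, snapshot)

    s = EvoNormSample2d(16, groups=8)  # channel count must divide evenly by groups
    assert s(x).shape == x.shape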