diff --git a/timm/models/densenet.py b/timm/models/densenet.py index 4b774c4a..420680f9 100644 --- a/timm/models/densenet.py +++ b/timm/models/densenet.py @@ -13,7 +13,7 @@ from torch.jit.annotations import List from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import load_pretrained -from .layers import SelectAdaptivePool2d +from .layers import SelectAdaptivePool2d, BatchNormAct2d, EvoNormBatch2d, EvoNormSample2d from .registry import register_model __all__ = ['DenseNet'] @@ -35,90 +35,88 @@ default_cfgs = { 'densenet169': _cfg(url='https://download.pytorch.org/models/densenet169-b2777c0a.pth'), 'densenet201': _cfg(url='https://download.pytorch.org/models/densenet201-c1103571.pth'), 'densenet161': _cfg(url='https://download.pytorch.org/models/densenet161-8d451a50.pth'), + 'densenet264': _cfg(url=''), } -class _DenseLayer(nn.Module): - def __init__(self, num_input_features, growth_rate, bn_size, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, +class DenseLayer(nn.Module): + def __init__(self, num_input_features, growth_rate, bn_size, norm_act_layer=BatchNormAct2d, drop_rate=0., memory_efficient=False): - super(_DenseLayer, self).__init__() - self.add_module('norm1', norm_layer(num_input_features)), - self.add_module('relu1', act_layer(inplace=True)), + super(DenseLayer, self).__init__() + self.add_module('norm1', norm_act_layer(num_input_features)), self.add_module('conv1', nn.Conv2d( num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)), - self.add_module('norm2', norm_layer(bn_size * growth_rate)), - self.add_module('relu2', act_layer(inplace=True)), + self.add_module('norm2', norm_act_layer(bn_size * growth_rate)), self.add_module('conv2', nn.Conv2d( bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)), self.drop_rate = float(drop_rate) self.memory_efficient = memory_efficient - def bn_function(self, inputs): + def bottleneck_fn(self, xs): # type: (List[torch.Tensor]) -> 
torch.Tensor - concated_features = torch.cat(inputs, 1) - bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features))) # noqa: T484 + concated_features = torch.cat(xs, 1) + bottleneck_output = self.conv1(self.norm1(concated_features)) # noqa: T484 return bottleneck_output # todo: rewrite when torchscript supports any - def any_requires_grad(self, input): + def any_requires_grad(self, x): # type: (List[torch.Tensor]) -> bool - for tensor in input: + for tensor in x: if tensor.requires_grad: return True return False @torch.jit.unused # noqa: T484 - def call_checkpoint_bottleneck(self, input): + def call_checkpoint_bottleneck(self, x): # type: (List[torch.Tensor]) -> torch.Tensor - def closure(*inputs): - return self.bn_function(*inputs) + def closure(*xs): + return self.bottleneck_fn(*xs) - return cp.checkpoint(closure, input) + return cp.checkpoint(closure, x) @torch.jit._overload_method # noqa: F811 - def forward(self, input): + def forward(self, x): # type: (List[torch.Tensor]) -> (torch.Tensor) pass @torch.jit._overload_method # noqa: F811 - def forward(self, input): + def forward(self, x): # type: (torch.Tensor) -> (torch.Tensor) pass # torchscript does not yet support *args, so we overload method # allowing it to take either a List[Tensor] or single Tensor - def forward(self, input): # noqa: F811 - if isinstance(input, torch.Tensor): - prev_features = [input] + def forward(self, x): # noqa: F811 + if isinstance(x, torch.Tensor): + prev_features = [x] else: - prev_features = input + prev_features = x if self.memory_efficient and self.any_requires_grad(prev_features): if torch.jit.is_scripting(): raise Exception("Memory Efficient not supported in JIT") bottleneck_output = self.call_checkpoint_bottleneck(prev_features) else: - bottleneck_output = self.bn_function(prev_features) + bottleneck_output = self.bottleneck_fn(prev_features) - new_features = self.conv2(self.relu2(self.norm2(bottleneck_output))) + new_features = 
self.conv2(self.norm2(bottleneck_output)) if self.drop_rate > 0: new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) return new_features -class _DenseBlock(nn.ModuleDict): +class DenseBlock(nn.ModuleDict): _version = 2 - def __init__(self, num_layers, num_input_features, bn_size, growth_rate, act_layer=nn.ReLU, - norm_layer=nn.BatchNorm2d, drop_rate=0., memory_efficient=False): - super(_DenseBlock, self).__init__() + def __init__(self, num_layers, num_input_features, bn_size, growth_rate, norm_act_layer=BatchNormAct2d, + drop_rate=0., memory_efficient=False): + super(DenseBlock, self).__init__() for i in range(num_layers): - layer = _DenseLayer( + layer = DenseLayer( num_input_features + i * growth_rate, growth_rate=growth_rate, bn_size=bn_size, - act_layer=act_layer, - norm_layer=norm_layer, + norm_act_layer=norm_act_layer, drop_rate=drop_rate, memory_efficient=memory_efficient, ) @@ -132,11 +130,10 @@ class _DenseBlock(nn.ModuleDict): return torch.cat(features, 1) -class _Transition(nn.Sequential): - def __init__(self, num_input_features, num_output_features, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): - super(_Transition, self).__init__() - self.add_module('norm', norm_layer(num_input_features)) - self.add_module('relu', act_layer(inplace=True)) +class DenseTransition(nn.Sequential): + def __init__(self, num_input_features, num_output_features, norm_act_layer=BatchNormAct2d): + super(DenseTransition, self).__init__() + self.add_module('norm', norm_act_layer(num_input_features)) self.add_module('conv', nn.Conv2d( num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)) self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) @@ -149,7 +146,6 @@ class DenseNet(nn.Module): Args: growth_rate (int) - how many filters to add each layer (`k` in paper) block_config (list of 4 ints) - how many layers in each pooling block - num_init_features (int) - the number of filters to learn in the first convolution layer bn_size
(int) - multiplicative factor for number of bottle neck layers (i.e. bn_size * k features in the bottleneck layer) drop_rate (float) - dropout rate after each dense layer @@ -158,67 +154,66 @@ class DenseNet(nn.Module): but slower. Default: *False*. See `"paper" `_ """ - def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64, - bn_size=4, stem_type='', num_classes=1000, in_chans=3, global_pool='avg', - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0, memory_efficient=False): + def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), bn_size=4, stem_type='', + num_classes=1000, in_chans=3, global_pool='avg', + norm_act_layer=BatchNormAct2d, aa_layer=None, drop_rate=0, memory_efficient=False): self.num_classes = num_classes self.drop_rate = drop_rate - deep_stem = 'deep' in stem_type super(DenseNet, self).__init__() - # First convolution + # Stem + deep_stem = 'deep' in stem_type # 3x3 deep stem + num_init_features = growth_rate * 2 if aa_layer is None: - max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + stem_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) else: - max_pool = nn.Sequential(*[ + stem_pool = nn.Sequential(*[ nn.MaxPool2d(kernel_size=3, stride=1, padding=1), - aa_layer(channels=self.inplanes, stride=2)]) + aa_layer(channels=num_init_features, stride=2)]) if deep_stem: - stem_chs_1 = stem_chs_2 = num_init_features // 2 + stem_chs_1 = stem_chs_2 = growth_rate if 'tiered' in stem_type: - stem_chs_1 = 3 * (num_init_features // 8) - stem_chs_2 = num_init_features if 'narrow' in stem_type else 6 * (num_init_features // 8) + stem_chs_1 = 3 * (growth_rate // 4) + stem_chs_2 = num_init_features if 'narrow' in stem_type else 6 * (growth_rate // 4) self.features = nn.Sequential(OrderedDict([ ('conv0', nn.Conv2d(in_chans, stem_chs_1, 3, stride=2, padding=1, bias=False)), - ('norm0', norm_layer(stem_chs_1)), - ('relu0', act_layer(inplace=True)), + ('norm0', 
norm_act_layer(stem_chs_1)), ('conv1', nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False)), - ('norm1', norm_layer(stem_chs_2)), - ('relu1', act_layer(inplace=True)), + ('norm1', norm_act_layer(stem_chs_2)), ('conv2', nn.Conv2d(stem_chs_2, num_init_features, 3, stride=1, padding=1, bias=False)), - ('norm2', norm_layer(num_init_features)), - ('relu2', act_layer(inplace=True)), - ('pool0', max_pool), + ('norm2', norm_act_layer(num_init_features)), + ('pool0', stem_pool), ])) else: self.features = nn.Sequential(OrderedDict([ ('conv0', nn.Conv2d(in_chans, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), - ('norm0', norm_layer(num_init_features)), - ('relu0', act_layer(inplace=True)), - ('pool0', max_pool), + ('norm0', norm_act_layer(num_init_features)), + ('pool0', stem_pool), ])) - # Each denseblock + # DenseBlocks num_features = num_init_features for i, num_layers in enumerate(block_config): - block = _DenseBlock( + block = DenseBlock( num_layers=num_layers, num_input_features=num_features, bn_size=bn_size, growth_rate=growth_rate, + norm_act_layer=norm_act_layer, drop_rate=drop_rate, memory_efficient=memory_efficient ) self.features.add_module('denseblock%d' % (i + 1), block) num_features = num_features + num_layers * growth_rate if i != len(block_config) - 1: - trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2) + trans = DenseTransition( + num_input_features=num_features, num_output_features=num_features // 2, + norm_act_layer=norm_act_layer) self.features.add_module('transition%d' % (i + 1), trans) num_features = num_features // 2 # Final batch norm - self.features.add_module('norm5', norm_layer(num_features)) - self.act = act_layer(inplace=True) + self.features.add_module('norm5', norm_act_layer(num_features)) # Linear layer self.num_features = num_features @@ -248,9 +243,7 @@ class DenseNet(nn.Module): self.classifier = nn.Identity() def forward_features(self, x): - x = 
self.features(x) - x = self.act(x) - return x + return self.features(x) def forward(self, x): x = self.forward_features(x) @@ -275,7 +268,7 @@ def _filter_torchvision_pretrained(state_dict): return state_dict -def _densenet(variant, growth_rate, block_config, num_init_features, pretrained, **kwargs): +def _densenet(variant, growth_rate, block_config, pretrained, **kwargs): if kwargs.pop('features_only', False): assert False, 'Not Implemented' # TODO load_strict = False @@ -285,8 +278,7 @@ def _densenet(variant, growth_rate, block_config, num_init_features, pretrained, load_strict = True model_class = DenseNet default_cfg = default_cfgs[variant] - model = model_class( - growth_rate=growth_rate, block_config=block_config, num_init_features=num_init_features, **kwargs) + model = model_class(growth_rate=growth_rate, block_config=block_config, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained( @@ -304,8 +296,7 @@ def densenet121(pretrained=False, **kwargs): `"Densely Connected Convolutional Networks" ` """ model = _densenet( - 'densenet121', growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64, - pretrained=pretrained, **kwargs) + 'densenet121', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, **kwargs) return model @@ -315,8 +306,8 @@ def densenet121d(pretrained=False, **kwargs): `"Densely Connected Convolutional Networks" ` """ model = _densenet( - 'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64, - stem_type='deep', pretrained=pretrained, **kwargs) + 'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep', + pretrained=pretrained, **kwargs) return model @@ -326,8 +317,42 @@ def densenet121tn(pretrained=False, **kwargs): `"Densely Connected Convolutional Networks" ` """ model = _densenet( - 'densenet121tn', growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64, - stem_type='deep_tiered_narrow', pretrained=pretrained, **kwargs) + 
'densenet121tn', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep_tiered_narrow', + pretrained=pretrained, **kwargs) + return model + + +@register_model +def densenet121d_evob(pretrained=False, **kwargs): + r"""Densenet-121 model from + `"Densely Connected Convolutional Networks" ` + """ + model = _densenet( + 'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep', + norm_act_layer=EvoNormBatch2d, pretrained=pretrained, **kwargs) + return model + + +@register_model +def densenet121d_evos(pretrained=False, **kwargs): + r"""Densenet-121 model from + `"Densely Connected Convolutional Networks" ` + """ + model = _densenet( + 'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep', + norm_act_layer=EvoNormSample2d, pretrained=pretrained, **kwargs) + return model + + +@register_model +def densenet121d_iabn(pretrained=False, **kwargs): + r"""Densenet-121 model from + `"Densely Connected Convolutional Networks" ` + """ + from inplace_abn import InPlaceABN + model = _densenet( + 'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep', + norm_act_layer=InPlaceABN, pretrained=pretrained, **kwargs) return model @@ -337,8 +362,7 @@ def densenet169(pretrained=False, **kwargs): `"Densely Connected Convolutional Networks" ` """ model = _densenet( - 'densenet169', growth_rate=32, block_config=(6, 12, 32, 32), num_init_features=64, - pretrained=pretrained, **kwargs) + 'densenet169', growth_rate=32, block_config=(6, 12, 32, 32), pretrained=pretrained, **kwargs) return model @@ -348,17 +372,25 @@ def densenet201(pretrained=False, **kwargs): `"Densely Connected Convolutional Networks" ` """ model = _densenet( - 'densenet201', growth_rate=32, block_config=(6, 12, 48, 32), num_init_features=64, - pretrained=pretrained, **kwargs) + 'densenet201', growth_rate=32, block_config=(6, 12, 48, 32), pretrained=pretrained, **kwargs) return model @register_model def densenet161(pretrained=False, **kwargs): - 
r"""Densenet-201 model from + r"""Densenet-161 model from `"Densely Connected Convolutional Networks" ` """ model = _densenet( - 'densenet161', growth_rate=48, block_config=(6, 12, 36, 24), num_init_features=96, - pretrained=pretrained, **kwargs) + 'densenet161', growth_rate=48, block_config=(6, 12, 36, 24), pretrained=pretrained, **kwargs) + return model + + +@register_model +def densenet264(pretrained=False, **kwargs): + r"""Densenet-264 model from + `"Densely Connected Convolutional Networks" ` + """ + model = _densenet( + 'densenet264', growth_rate=48, block_config=(6, 12, 64, 48), pretrained=pretrained, **kwargs) return model diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 4f84bb9e..12e7326e 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -19,3 +19,5 @@ from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model from .anti_aliasing import AntiAliasDownsampleLayer from .space_to_depth import SpaceToDepthModule from .blur_pool import BlurPool2d +from .norm_act import BatchNormAct2d +from .evo_norm import EvoNormBatch2d, EvoNormSample2d \ No newline at end of file diff --git a/timm/models/layers/evo_norm.py b/timm/models/layers/evo_norm.py new file mode 100644 index 00000000..79de23e9 --- /dev/null +++ b/timm/models/layers/evo_norm.py @@ -0,0 +1,134 @@ +"""EvoNormB0 (Batched) and EvoNormS0 (Sample) in PyTorch + +An attempt at getting decent performing EvoNorms running in PyTorch. +While currently faster than other impl, still quite a ways off the built-in BN +in terms of memory usage and throughput. + +Still very much a WIP, fiddling with buffer usage, in-place optimizations, and layouts. 
+ +Hacked together by Ross Wightman +""" + +import torch +import torch.nn as nn + + +@torch.jit.script +def evo_batch_jit( + x: torch.Tensor, v: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, running_var: torch.Tensor, + momentum: float, training: bool, nonlin: bool, eps: float): + x_type = x.dtype + running_var = running_var.detach() # FIXME why is this needed, it's a buffer? + if training: + var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True) # FIXME biased, unbiased? + running_var.copy_(momentum * var + (1 - momentum) * running_var) + else: + var = running_var.clone() + + if nonlin: + # FIXME biased, unbiased? + d = (x * v.to(x_type)) + x.var(dim=(2, 3), unbiased=False, keepdim=True).add_(eps).sqrt_().to(dtype=x_type) + d = d.max(var.add(eps).sqrt_().to(dtype=x_type)) + x = x / d + return x.mul_(weight).add_(bias) + else: + return x.mul(weight).add_(bias) + + +class EvoNormBatch2d(nn.Module): + def __init__(self, num_features, momentum=0.1, nonlin=True, eps=1e-5, jit=True): + super(EvoNormBatch2d, self).__init__() + self.momentum = momentum + self.nonlin = nonlin + self.eps = eps + self.jit = jit + param_shape = (1, num_features, 1, 1) + self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True) + self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True) + if nonlin: + self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True) + self.register_buffer('running_var', torch.ones(1, num_features, 1, 1)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + if self.nonlin: + nn.init.ones_(self.v) + + def forward(self, x): + assert x.dim() == 4, 'expected 4D input' + + if self.jit: + return evo_batch_jit( + x, self.v, self.weight, self.bias, self.running_var, self.momentum, + self.training, self.nonlin, self.eps) + else: + x_type = x.dtype + if self.training: + var = x.var(dim=(0, 2, 3), keepdim=True) + self.running_var.copy_(self.momentum * var + (1 - 
self.momentum) * self.running_var) + else: + var = self.running_var.clone() + + if self.nonlin: + v = self.v.to(dtype=x_type) + d = (x * v) + x.var(dim=(2, 3), keepdim=True).add_(self.eps).sqrt_().to(dtype=x_type) + d = d.max(var.add(self.eps).sqrt_().to(dtype=x_type)) + x = x / d + return x.mul_(self.weight).add_(self.bias) + else: + return x.mul(self.weight).add_(self.bias) + + +@torch.jit.script +def evo_sample_jit( + x: torch.Tensor, v: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, + groups: int, nonlin: bool, eps: float): + B, C, H, W = x.shape + assert C % groups == 0 + if nonlin: + n = (x * v).sigmoid_().reshape(B, groups, -1) + x = x.reshape(B, groups, -1) + x = n / x.var(dim=-1, unbiased=False, keepdim=True).add_(eps).sqrt_() + x = x.reshape(B, C, H, W) + return x.mul_(weight).add_(bias) + + +class EvoNormSample2d(nn.Module): + def __init__(self, num_features, nonlin=True, groups=8, eps=1e-5, jit=True): + super(EvoNormSample2d, self).__init__() + self.nonlin = nonlin + self.groups = groups + self.eps = eps + self.jit = jit + param_shape = (1, num_features, 1, 1) + self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True) + self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True) + if nonlin: + self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True) + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + if self.nonlin: + nn.init.ones_(self.v) + + def forward(self, x): + assert x.dim() == 4, 'expected 4D input' + + if self.jit: + return evo_sample_jit( + x, self.v, self.weight, self.bias, self.groups, self.nonlin, self.eps) + else: + B, C, H, W = x.shape + assert C % self.groups == 0 + if self.nonlin: + n = (x * self.v).sigmoid().reshape(B, self.groups, -1) + x = x.reshape(B, self.groups, -1) + x = n / (x.std(dim=-1, unbiased=False, keepdim=True) + self.eps) + x = x.reshape(B, C, H, W) + return x.mul_(self.weight).add_(self.bias) + else: + return 
x.mul(self.weight).add_(self.bias) diff --git a/timm/models/layers/norm_act.py b/timm/models/layers/norm_act.py new file mode 100644 index 00000000..879a8939 --- /dev/null +++ b/timm/models/layers/norm_act.py @@ -0,0 +1,50 @@ +""" Normalization + Activation Layers +""" +from torch import nn as nn +from torch.nn import functional as F + + +class BatchNormAct2d(nn.BatchNorm2d): + """BatchNorm + Activation + + This module performs BatchNorm + Activation in a manner that will remain backwards + compatible with weights trained with separate bn, act. This is why we inherit from BN + instead of composing it as a .bn member. + """ + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, + track_running_stats=True, act_layer=nn.ReLU, inplace=True): + super(BatchNormAct2d, self).__init__(num_features, eps, momentum, affine, track_running_stats) + self.act = act_layer(inplace=inplace) + + def forward(self, x): + # FIXME cannot call parent forward() and maintain jit.script compatibility? + # x = super(BatchNormAct2d, self).forward(x) + + # BEGIN nn.BatchNorm2d forward() cut & paste + # self._check_input_dim(x) + + # exponential_average_factor is self.momentum set to + # (when it is available) only so that if gets updated + # in ONNX graph when this node is exported to ONNX.
+ if self.momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = self.momentum + + if self.training and self.track_running_stats: + # TODO: if statement only here to tell the jit to skip emitting this when it is None + if self.num_batches_tracked is not None: + self.num_batches_tracked += 1 + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / float(self.num_batches_tracked) + else: # use exponential moving average + exponential_average_factor = self.momentum + + x = F.batch_norm( + x, self.running_mean, self.running_var, self.weight, self.bias, + self.training or not self.track_running_stats, + exponential_average_factor, self.eps) + # END BatchNorm2d forward() + + x = self.act(x) + return x