DenseNet converted to support ABN (norm + act) modules. Experimenting with EvoNorm, IABN

pull/155/head
Ross Wightman 4 years ago
parent 022ed001f3
commit 14edacdf9a

@ -13,7 +13,7 @@ from torch.jit.annotations import List
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d
from .layers import SelectAdaptivePool2d, BatchNormAct2d, EvoNormBatch2d, EvoNormSample2d
from .registry import register_model
__all__ = ['DenseNet']
@ -35,90 +35,88 @@ default_cfgs = {
'densenet169': _cfg(url='https://download.pytorch.org/models/densenet169-b2777c0a.pth'),
'densenet201': _cfg(url='https://download.pytorch.org/models/densenet201-c1103571.pth'),
'densenet161': _cfg(url='https://download.pytorch.org/models/densenet161-8d451a50.pth'),
'densenet264': _cfg(url=''),
}
class _DenseLayer(nn.Module):
def __init__(self, num_input_features, growth_rate, bn_size, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
class DenseLayer(nn.Module):
def __init__(self, num_input_features, growth_rate, bn_size, norm_act_layer=BatchNormAct2d,
drop_rate=0., memory_efficient=False):
super(_DenseLayer, self).__init__()
self.add_module('norm1', norm_layer(num_input_features)),
self.add_module('relu1', act_layer(inplace=True)),
super(DenseLayer, self).__init__()
self.add_module('norm1', norm_act_layer(num_input_features)),
self.add_module('conv1', nn.Conv2d(
num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)),
self.add_module('norm2', norm_layer(bn_size * growth_rate)),
self.add_module('relu2', act_layer(inplace=True)),
self.add_module('norm2', norm_act_layer(bn_size * growth_rate)),
self.add_module('conv2', nn.Conv2d(
bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)),
self.drop_rate = float(drop_rate)
self.memory_efficient = memory_efficient
def bn_function(self, inputs):
def bottleneck_fn(self, xs):
# type: (List[torch.Tensor]) -> torch.Tensor
concated_features = torch.cat(inputs, 1)
bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features))) # noqa: T484
concated_features = torch.cat(xs, 1)
bottleneck_output = self.conv1(self.norm1(concated_features)) # noqa: T484
return bottleneck_output
# todo: rewrite when torchscript supports any
def any_requires_grad(self, input):
def any_requires_grad(self, x):
# type: (List[torch.Tensor]) -> bool
for tensor in input:
for tensor in x:
if tensor.requires_grad:
return True
return False
@torch.jit.unused # noqa: T484
def call_checkpoint_bottleneck(self, input):
def call_checkpoint_bottleneck(self, x):
# type: (List[torch.Tensor]) -> torch.Tensor
def closure(*inputs):
return self.bn_function(*inputs)
def closure(*xs):
return self.bottleneck_fn(*xs)
return cp.checkpoint(closure, input)
return cp.checkpoint(closure, x)
@torch.jit._overload_method # noqa: F811
def forward(self, input):
def forward(self, x):
# type: (List[torch.Tensor]) -> (torch.Tensor)
pass
@torch.jit._overload_method # noqa: F811
def forward(self, input):
def forward(self, x):
# type: (torch.Tensor) -> (torch.Tensor)
pass
# torchscript does not yet support *args, so we overload method
# allowing it to take either a List[Tensor] or single Tensor
def forward(self, input): # noqa: F811
if isinstance(input, torch.Tensor):
prev_features = [input]
def forward(self, x): # noqa: F811
if isinstance(x, torch.Tensor):
prev_features = [x]
else:
prev_features = input
prev_features = x
if self.memory_efficient and self.any_requires_grad(prev_features):
if torch.jit.is_scripting():
raise Exception("Memory Efficient not supported in JIT")
bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
else:
bottleneck_output = self.bn_function(prev_features)
bottleneck_output = self.bottleneck_fn(prev_features)
new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
new_features = self.conv2(self.norm2(bottleneck_output))
if self.drop_rate > 0:
new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
return new_features
class _DenseBlock(nn.ModuleDict):
class DenseBlock(nn.ModuleDict):
_version = 2
def __init__(self, num_layers, num_input_features, bn_size, growth_rate, act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d, drop_rate=0., memory_efficient=False):
super(_DenseBlock, self).__init__()
def __init__(self, num_layers, num_input_features, bn_size, growth_rate, norm_act_layer=nn.ReLU,
drop_rate=0., memory_efficient=False):
super(DenseBlock, self).__init__()
for i in range(num_layers):
layer = _DenseLayer(
layer = DenseLayer(
num_input_features + i * growth_rate,
growth_rate=growth_rate,
bn_size=bn_size,
act_layer=act_layer,
norm_layer=norm_layer,
norm_act_layer=norm_act_layer,
drop_rate=drop_rate,
memory_efficient=memory_efficient,
)
@ -132,11 +130,10 @@ class _DenseBlock(nn.ModuleDict):
return torch.cat(features, 1)
class _Transition(nn.Sequential):
def __init__(self, num_input_features, num_output_features, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
super(_Transition, self).__init__()
self.add_module('norm', norm_layer(num_input_features))
self.add_module('relu', act_layer(inplace=True))
class DenseTransition(nn.Sequential):
def __init__(self, num_input_features, num_output_features, norm_act_layer=nn.BatchNorm2d):
super(DenseTransition, self).__init__()
self.add_module('norm', norm_act_layer(num_input_features))
self.add_module('conv', nn.Conv2d(
num_input_features, num_output_features, kernel_size=1, stride=1, bias=False))
self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
@ -149,7 +146,6 @@ class DenseNet(nn.Module):
Args:
growth_rate (int) - how many filters to add each layer (`k` in paper)
block_config (list of 4 ints) - how many layers in each pooling block
num_init_features (int) - the number of filters to learn in the first convolution layer
bn_size (int) - multiplicative factor for number of bottle neck layers
(i.e. bn_size * k features in the bottleneck layer)
drop_rate (float) - dropout rate after each dense layer
@ -158,67 +154,66 @@ class DenseNet(nn.Module):
but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
"""
def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64,
bn_size=4, stem_type='', num_classes=1000, in_chans=3, global_pool='avg',
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0, memory_efficient=False):
def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), bn_size=4, stem_type='',
num_classes=1000, in_chans=3, global_pool='avg',
norm_act_layer=BatchNormAct2d, aa_layer=None, drop_rate=0, memory_efficient=False):
self.num_classes = num_classes
self.drop_rate = drop_rate
deep_stem = 'deep' in stem_type
super(DenseNet, self).__init__()
# First convolution
# Stem
deep_stem = 'deep' in stem_type # 3x3 deep stem
num_init_features = growth_rate * 2
if aa_layer is None:
max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
stem_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
else:
max_pool = nn.Sequential(*[
stem_pool = nn.Sequential(*[
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
aa_layer(channels=self.inplanes, stride=2)])
aa_layer(channels=num_init_features, stride=2)])
if deep_stem:
stem_chs_1 = stem_chs_2 = num_init_features // 2
stem_chs_1 = stem_chs_2 = growth_rate
if 'tiered' in stem_type:
stem_chs_1 = 3 * (num_init_features // 8)
stem_chs_2 = num_init_features if 'narrow' in stem_type else 6 * (num_init_features // 8)
stem_chs_1 = 3 * (growth_rate // 4)
stem_chs_2 = num_init_features if 'narrow' in stem_type else 6 * (growth_rate // 4)
self.features = nn.Sequential(OrderedDict([
('conv0', nn.Conv2d(in_chans, stem_chs_1, 3, stride=2, padding=1, bias=False)),
('norm0', norm_layer(stem_chs_1)),
('relu0', act_layer(inplace=True)),
('norm0', norm_act_layer(stem_chs_1)),
('conv1', nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False)),
('norm1', norm_layer(stem_chs_2)),
('relu1', act_layer(inplace=True)),
('norm1', norm_act_layer(stem_chs_2)),
('conv2', nn.Conv2d(stem_chs_2, num_init_features, 3, stride=1, padding=1, bias=False)),
('norm2', norm_layer(num_init_features)),
('relu2', act_layer(inplace=True)),
('pool0', max_pool),
('norm2', norm_act_layer(num_init_features)),
('pool0', stem_pool),
]))
else:
self.features = nn.Sequential(OrderedDict([
('conv0', nn.Conv2d(in_chans, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
('norm0', norm_layer(num_init_features)),
('relu0', act_layer(inplace=True)),
('pool0', max_pool),
('norm0', norm_act_layer(num_init_features)),
('pool0', stem_pool),
]))
# Each denseblock
# DenseBlocks
num_features = num_init_features
for i, num_layers in enumerate(block_config):
block = _DenseBlock(
block = DenseBlock(
num_layers=num_layers,
num_input_features=num_features,
bn_size=bn_size,
growth_rate=growth_rate,
norm_act_layer=norm_act_layer,
drop_rate=drop_rate,
memory_efficient=memory_efficient
)
self.features.add_module('denseblock%d' % (i + 1), block)
num_features = num_features + num_layers * growth_rate
if i != len(block_config) - 1:
trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
trans = DenseTransition(
num_input_features=num_features, num_output_features=num_features // 2,
norm_act_layer=norm_act_layer)
self.features.add_module('transition%d' % (i + 1), trans)
num_features = num_features // 2
# Final batch norm
self.features.add_module('norm5', norm_layer(num_features))
self.act = act_layer(inplace=True)
self.features.add_module('norm5', norm_act_layer(num_features))
# Linear layer
self.num_features = num_features
@ -248,9 +243,7 @@ class DenseNet(nn.Module):
self.classifier = nn.Identity()
def forward_features(self, x):
x = self.features(x)
x = self.act(x)
return x
return self.features(x)
def forward(self, x):
x = self.forward_features(x)
@ -275,7 +268,7 @@ def _filter_torchvision_pretrained(state_dict):
return state_dict
def _densenet(variant, growth_rate, block_config, num_init_features, pretrained, **kwargs):
def _densenet(variant, growth_rate, block_config, pretrained, **kwargs):
if kwargs.pop('features_only', False):
assert False, 'Not Implemented' # TODO
load_strict = False
@ -285,8 +278,7 @@ def _densenet(variant, growth_rate, block_config, num_init_features, pretrained,
load_strict = True
model_class = DenseNet
default_cfg = default_cfgs[variant]
model = model_class(
growth_rate=growth_rate, block_config=block_config, num_init_features=num_init_features, **kwargs)
model = model_class(growth_rate=growth_rate, block_config=block_config, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(
@ -304,8 +296,7 @@ def densenet121(pretrained=False, **kwargs):
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
"""
model = _densenet(
'densenet121', growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64,
pretrained=pretrained, **kwargs)
'densenet121', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, **kwargs)
return model
@ -315,8 +306,8 @@ def densenet121d(pretrained=False, **kwargs):
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
"""
model = _densenet(
'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64,
stem_type='deep', pretrained=pretrained, **kwargs)
'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
pretrained=pretrained, **kwargs)
return model
@ -326,8 +317,42 @@ def densenet121tn(pretrained=False, **kwargs):
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
"""
model = _densenet(
'densenet121tn', growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64,
stem_type='deep_tiered_narrow', pretrained=pretrained, **kwargs)
'densenet121tn', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep_tiered_narrow',
pretrained=pretrained, **kwargs)
return model
@register_model
def densenet121d_evob(pretrained=False, **kwargs):
r"""Densenet-121 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
"""
model = _densenet(
'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
norm_act_layer=EvoNormBatch2d, pretrained=pretrained, **kwargs)
return model
@register_model
def densenet121d_evos(pretrained=False, **kwargs):
r"""Densenet-121 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
"""
model = _densenet(
'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
norm_act_layer=EvoNormSample2d, pretrained=pretrained, **kwargs)
return model
@register_model
def densenet121d_iabn(pretrained=False, **kwargs):
r"""Densenet-121 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
"""
from inplace_abn import InPlaceABN
model = _densenet(
'densenet121tn', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
norm_act_layer=InPlaceABN, pretrained=pretrained, **kwargs)
return model
@ -337,8 +362,7 @@ def densenet169(pretrained=False, **kwargs):
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
"""
model = _densenet(
'densenet169', growth_rate=32, block_config=(6, 12, 32, 32), num_init_features=64,
pretrained=pretrained, **kwargs)
'densenet169', growth_rate=32, block_config=(6, 12, 32, 32), pretrained=pretrained, **kwargs)
return model
@ -348,17 +372,25 @@ def densenet201(pretrained=False, **kwargs):
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
"""
model = _densenet(
'densenet201', growth_rate=32, block_config=(6, 12, 48, 32), num_init_features=64,
pretrained=pretrained, **kwargs)
'densenet201', growth_rate=32, block_config=(6, 12, 48, 32), pretrained=pretrained, **kwargs)
return model
@register_model
def densenet161(pretrained=False, **kwargs):
r"""Densenet-201 model from
r"""Densenet-161 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
"""
model = _densenet(
'densenet161', growth_rate=48, block_config=(6, 12, 36, 24), num_init_features=96,
pretrained=pretrained, **kwargs)
'densenet161', growth_rate=48, block_config=(6, 12, 36, 24), pretrained=pretrained, **kwargs)
return model
@register_model
def densenet264(pretrained=False, **kwargs):
r"""Densenet-264 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
"""
model = _densenet(
'densenet264', growth_rate=48, block_config=(6, 12, 64, 48), pretrained=pretrained, **kwargs)
return model

@ -19,3 +19,5 @@ from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .anti_aliasing import AntiAliasDownsampleLayer
from .space_to_depth import SpaceToDepthModule
from .blur_pool import BlurPool2d
from .norm_act import BatchNormAct2d
from .evo_norm import EvoNormBatch2d, EvoNormSample2d

@ -0,0 +1,134 @@
"""EvoNormB0 (Batched) and EvoNormS0 (Sample) in PyTorch
An attempt at getting decent performing EvoNorms running in PyTorch.
While currently faster than other impl, still quite a ways off the built-in BN
in terms of memory usage and throughput.
Still very much a WIP, fiddling with buffer usage, in-place optimizations, and layouts.
Hacked together by Ross Wightman
"""
import torch
import torch.nn as nn
@torch.jit.script
def evo_batch_jit(
x: torch.Tensor, v: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, running_var: torch.Tensor,
momentum: float, training: bool, nonlin: bool, eps: float):
x_type = x.dtype
running_var = running_var.detach() # FIXME why is this needed, it's a buffer?
if training:
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True) # FIXME biased, unbiased?
running_var.copy_(momentum * var + (1 - momentum) * running_var)
else:
var = running_var.clone()
if nonlin:
# FIXME biased, unbiased?
d = (x * v.to(x_type)) + x.var(dim=(2, 3), unbiased=False, keepdim=True).add_(eps).sqrt_().to(dtype=x_type)
d = d.max(var.add(eps).sqrt_().to(dtype=x_type))
x = x / d
return x.mul_(weight).add_(bias)
else:
return x.mul(weight).add_(bias)
class EvoNormBatch2d(nn.Module):
def __init__(self, num_features, momentum=0.1, nonlin=True, eps=1e-5, jit=True):
super(EvoNormBatch2d, self).__init__()
self.momentum = momentum
self.nonlin = nonlin
self.eps = eps
self.jit = jit
param_shape = (1, num_features, 1, 1)
self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True)
self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True)
if nonlin:
self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True)
self.register_buffer('running_var', torch.ones(1, num_features, 1, 1))
self.reset_parameters()
def reset_parameters(self):
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
if self.nonlin:
nn.init.ones_(self.v)
def forward(self, x):
assert x.dim() == 4, 'expected 4D input'
if self.jit:
return evo_batch_jit(
x, self.v, self.weight, self.bias, self.running_var, self.momentum,
self.training, self.nonlin, self.eps)
else:
x_type = x.dtype
if self.training:
var = x.var(dim=(0, 2, 3), keepdim=True)
self.running_var.copy_(self.momentum * var + (1 - self.momentum) * self.running_var)
else:
var = self.running_var.clone()
if self.nonlin:
v = self.v.to(dtype=x_type)
d = (x * v) + x.var(dim=(2, 3), keepdim=True).add_(self.eps).sqrt_().to(dtype=x_type)
d = d.max(var.add(self.eps).sqrt_().to(dtype=x_type))
x = x / d
return x.mul_(self.weight).add_(self.bias)
else:
return x.mul(self.weight).add_(self.bias)
@torch.jit.script
def evo_sample_jit(
x: torch.Tensor, v: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
groups: int, nonlin: bool, eps: float):
B, C, H, W = x.shape
assert C % groups == 0
if nonlin:
n = (x * v).sigmoid_().reshape(B, groups, -1)
x = x.reshape(B, groups, -1)
x = n / x.var(dim=-1, unbiased=False, keepdim=True).add_(eps).sqrt_()
x = x.reshape(B, C, H, W)
return x.mul_(weight).add_(bias)
class EvoNormSample2d(nn.Module):
def __init__(self, num_features, nonlin=True, groups=8, eps=1e-5, jit=True):
super(EvoNormSample2d, self).__init__()
self.nonlin = nonlin
self.groups = groups
self.eps = eps
self.jit = jit
param_shape = (1, num_features, 1, 1)
self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True)
self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True)
if nonlin:
self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True)
self.reset_parameters()
def reset_parameters(self):
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
if self.nonlin:
nn.init.ones_(self.v)
def forward(self, x):
assert x.dim() == 4, 'expected 4D input'
if self.jit:
return evo_sample_jit(
x, self.v, self.weight, self.bias, self.groups, self.nonlin, self.eps)
else:
B, C, H, W = x.shape
assert C % self.groups == 0
if self.nonlin:
n = (x * self.v).sigmoid().reshape(B, self.groups, -1)
x = x.reshape(B, self.groups, -1)
x = n / (x.std(dim=-1, unbiased=False, keepdim=True) + self.eps)
x = x.reshape(B, C, H, W)
return x.mul_(self.weight).add_(self.bias)
else:
return x.mul(self.weight).add_(self.bias)

@ -0,0 +1,50 @@
""" Normalization + Activation Layers
"""
from torch import nn as nn
from torch.nn import functional as F
class BatchNormAct2d(nn.BatchNorm2d):
"""BatchNorm + Activation
This module performs BatchNorm + Actibation in s manner that will remain bavkwards
compatible with weights trained with separate bn, act. This is why we inherit from BN
instead of composing it as a .bn member.
"""
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
track_running_stats=True, act_layer=nn.ReLU, inplace=True):
super(BatchNormAct2d, self).__init__(num_features, eps, momentum, affine, track_running_stats)
self.act = act_layer(inplace=inplace)
def forward(self, x):
# FIXME cannot call parent forward() and maintain jit.script compatibility?
# x = super(BatchNormAct2d, self).forward(x)
# BEGIN nn.BatchNorm2d forward() cut & paste
# self._check_input_dim(x)
# exponential_average_factor is self.momentum set to
# (when it is available) only so that if gets updated
# in ONNX graph when this node is exported to ONNX.
if self.momentum is None:
exponential_average_factor = 0.0
else:
exponential_average_factor = self.momentum
if self.training and self.track_running_stats:
# TODO: if statement only here to tell the jit to skip emitting this when it is None
if self.num_batches_tracked is not None:
self.num_batches_tracked += 1
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else: # use exponential moving average
exponential_average_factor = self.momentum
x = F.batch_norm(
x, self.running_mean, self.running_var, self.weight, self.bias,
self.training or not self.track_running_stats,
exponential_average_factor, self.eps)
# END BatchNorm2d forward()
x = self.act(x)
return x
Loading…
Cancel
Save