Add norm_act factory method, move JIT of norm layers to factory

pull/155/head
Ross Wightman 4 years ago
parent 14edacdf9a
commit 780860d140

timm/models/densenet.py

@@ -4,6 +4,7 @@ fixed kwargs passthrough and addition of dynamic global avg/max pool.
 """
 import re
 from collections import OrderedDict
+from functools import partial
 import torch
 import torch.nn as nn
@@ -13,7 +14,7 @@ from torch.jit.annotations import List
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import load_pretrained
-from .layers import SelectAdaptivePool2d, BatchNormAct2d, EvoNormBatch2d, EvoNormSample2d
+from .layers import SelectAdaptivePool2d, BatchNormAct2d, create_norm_act
 from .registry import register_model
 __all__ = ['DenseNet']
@@ -327,9 +328,11 @@ def densenet121d_evob(pretrained=False, **kwargs):
     r"""Densenet-121 model from
     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
     """
+    def norm_act_fn(num_features, **kwargs):
+        return create_norm_act('EvoNormBatch', num_features, jit=True, **kwargs)
     model = _densenet(
         'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
-        norm_act_layer=EvoNormBatch2d, pretrained=pretrained, **kwargs)
+        norm_act_layer=norm_act_fn, pretrained=pretrained, **kwargs)
     return model
@@ -338,9 +341,11 @@ def densenet121d_evos(pretrained=False, **kwargs):
     r"""Densenet-121 model from
     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
     """
+    def norm_act_fn(num_features, **kwargs):
+        return create_norm_act('EvoNormSample', num_features, jit=True, **kwargs)
     model = _densenet(
         'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
-        norm_act_layer=EvoNormSample2d, pretrained=pretrained, **kwargs)
+        norm_act_layer=norm_act_fn, pretrained=pretrained, **kwargs)
     return model
@@ -349,10 +354,11 @@ def densenet121d_iabn(pretrained=False, **kwargs):
     r"""Densenet-121 model from
     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
     """
-    from inplace_abn import InPlaceABN
+    def norm_act_fn(num_features, **kwargs):
+        return create_norm_act('iabn', num_features, **kwargs)
     model = _densenet(
         'densenet121tn', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
-        norm_act_layer=InPlaceABN, pretrained=pretrained, **kwargs)
+        norm_act_layer=norm_act_fn, pretrained=pretrained, **kwargs)
     return model
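Each of the three wrappers above defines a local norm_act_fn closure that binds the layer type (and jit flag) now and lets _densenet supply num_features later. A minimal sketch of the same idea using the newly imported functools.partial; the 64-channel call is illustrative, not part of the diff:

from functools import partial

# Bind the layer type and jit flag up front; num_features is supplied at build time,
# mirroring what the local norm_act_fn closures in the wrappers above do.
norm_act_fn = partial(create_norm_act, 'EvoNormBatch', jit=True)
layer = norm_act_fn(64)  # same as create_norm_act('EvoNormBatch', 64, jit=True)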

timm/models/layers/__init__.py

@@ -20,4 +20,5 @@ from .anti_aliasing import AntiAliasDownsampleLayer
 from .space_to_depth import SpaceToDepthModule
 from .blur_pool import BlurPool2d
 from .norm_act import BatchNormAct2d
 from .evo_norm import EvoNormBatch2d, EvoNormSample2d
+from .create_norm_act import create_norm_act

timm/models/layers/create_norm_act.py (new file)

@@ -0,0 +1,37 @@
+import torch
+import torch.nn as nn
+
+from .evo_norm import EvoNormBatch2d, EvoNormSample2d
+from .norm_act import BatchNormAct2d
+
+try:
+    from inplace_abn import InPlaceABN
+    has_iabn = True
+except ImportError:
+    has_iabn = False
+
+
+def create_norm_act(layer_type, num_features, jit=False, **kwargs):
+    layer_parts = layer_type.split('_')
+    assert len(layer_parts) in (1, 2)
+    layer_class = layer_parts[0].lower()
+    #activation_class = layer_parts[1].lower() if len(layer_parts) > 1 else ''  # FIXME support string act selection
+    if layer_class == "batchnormact":
+        layer = BatchNormAct2d(num_features, **kwargs)  # defaults to ReLU if no kwargs override
+    elif layer_class == "batchnormrelu":
+        assert 'act_layer' not in kwargs
+        layer = BatchNormAct2d(num_features, act_layer=nn.ReLU, **kwargs)
+    elif layer_class == "evonormbatch":
+        layer = EvoNormBatch2d(num_features, **kwargs)
+    elif layer_class == "evonormsample":
+        layer = EvoNormSample2d(num_features, **kwargs)
+    elif layer_class == "iabn" or layer_class == "inplaceabn":
+        if not has_iabn:
+            raise ImportError(
+                "Please install InplaceABN: 'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11'")
+        layer = InPlaceABN(num_features, **kwargs)
+    else:
+        assert False, "Invalid norm_act layer (%s)" % layer_class
+    if jit:
+        layer = torch.jit.script(layer)
+    return layer
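For reference, a rough usage sketch of the new factory; the channel count and input tensor are made up for illustration, not taken from the commit:

import torch
from timm.models.layers import create_norm_act

bn_relu = create_norm_act('batchnormrelu', 64)        # BatchNormAct2d with nn.ReLU activation
evo = create_norm_act('EvoNormBatch', 64, jit=True)   # EvoNormBatch2d wrapped in torch.jit.script
out = evo(torch.randn(2, 64, 56, 56))                 # these layers expect 4D NCHW input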

timm/models/layers/evo_norm.py

@@ -13,35 +13,12 @@
 import torch
 import torch.nn as nn

-@torch.jit.script
-def evo_batch_jit(
-        x: torch.Tensor, v: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, running_var: torch.Tensor,
-        momentum: float, training: bool, nonlin: bool, eps: float):
-    x_type = x.dtype
-    running_var = running_var.detach()  # FIXME why is this needed, it's a buffer?
-    if training:
-        var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)  # FIXME biased, unbiased?
-        running_var.copy_(momentum * var + (1 - momentum) * running_var)
-    else:
-        var = running_var.clone()
-    if nonlin:
-        # FIXME biased, unbiased?
-        d = (x * v.to(x_type)) + x.var(dim=(2, 3), unbiased=False, keepdim=True).add_(eps).sqrt_().to(dtype=x_type)
-        d = d.max(var.add(eps).sqrt_().to(dtype=x_type))
-        x = x / d
-        return x.mul_(weight).add_(bias)
-    else:
-        return x.mul(weight).add_(bias)
-
-
 class EvoNormBatch2d(nn.Module):
-    def __init__(self, num_features, momentum=0.1, nonlin=True, eps=1e-5, jit=True):
+    def __init__(self, num_features, momentum=0.1, nonlin=True, eps=1e-5):
         super(EvoNormBatch2d, self).__init__()
         self.momentum = momentum
         self.nonlin = nonlin
         self.eps = eps
-        self.jit = jit
         param_shape = (1, num_features, 1, 1)
         self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True)
         self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True)

@@ -58,50 +35,29 @@ class EvoNormBatch2d(nn.Module):
     def forward(self, x):
         assert x.dim() == 4, 'expected 4D input'
-        if self.jit:
-            return evo_batch_jit(
-                x, self.v, self.weight, self.bias, self.running_var, self.momentum,
-                self.training, self.nonlin, self.eps)
-        else:
-            x_type = x.dtype
-            if self.training:
-                var = x.var(dim=(0, 2, 3), keepdim=True)
-                self.running_var.copy_(self.momentum * var + (1 - self.momentum) * self.running_var)
-            else:
-                var = self.running_var.clone()
-            if self.nonlin:
-                v = self.v.to(dtype=x_type)
-                d = (x * v) + x.var(dim=(2, 3), keepdim=True).add_(self.eps).sqrt_().to(dtype=x_type)
-                d = d.max(var.add(self.eps).sqrt_().to(dtype=x_type))
-                x = x / d
-                return x.mul_(self.weight).add_(self.bias)
-            else:
-                return x.mul(self.weight).add_(self.bias)
-
-
-@torch.jit.script
-def evo_sample_jit(
-        x: torch.Tensor, v: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
-        groups: int, nonlin: bool, eps: float):
-    B, C, H, W = x.shape
-    assert C % groups == 0
-    if nonlin:
-        n = (x * v).sigmoid_().reshape(B, groups, -1)
-        x = x.reshape(B, groups, -1)
-        x = n / x.var(dim=-1, unbiased=False, keepdim=True).add_(eps).sqrt_()
-        x = x.reshape(B, C, H, W)
-        return x.mul_(weight).add_(bias)
+        x_type = x.dtype
+        if self.training:
+            var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
+            self.running_var.copy_(self.momentum * var.detach() + (1 - self.momentum) * self.running_var)
+        else:
+            var = self.running_var.clone()
+
+        if self.nonlin:
+            v = self.v.to(dtype=x_type)
+            d = (x * v) + x.var(dim=(2, 3), unbiased=False, keepdim=True).add_(self.eps).sqrt_().to(dtype=x_type)
+            d = d.max(var.add_(self.eps).sqrt_().to(dtype=x_type))
+            x = x / d
+            return x.mul_(self.weight).add_(self.bias)
+        else:
+            return x.mul(self.weight).add_(self.bias)


 class EvoNormSample2d(nn.Module):
-    def __init__(self, num_features, nonlin=True, groups=8, eps=1e-5, jit=True):
+    def __init__(self, num_features, nonlin=True, groups=8, eps=1e-5):
         super(EvoNormSample2d, self).__init__()
         self.nonlin = nonlin
         self.groups = groups
         self.eps = eps
-        self.jit = jit
         param_shape = (1, num_features, 1, 1)
         self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True)
         self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True)

@@ -117,18 +73,13 @@ class EvoNormSample2d(nn.Module):
     def forward(self, x):
         assert x.dim() == 4, 'expected 4D input'
-        if self.jit:
-            return evo_sample_jit(
-                x, self.v, self.weight, self.bias, self.groups, self.nonlin, self.eps)
-        else:
-            B, C, H, W = x.shape
-            assert C % self.groups == 0
-            if self.nonlin:
-                n = (x * self.v).sigmoid().reshape(B, self.groups, -1)
-                x = x.reshape(B, self.groups, -1)
-                x = n / (x.std(dim=-1, unbiased=False, keepdim=True) + self.eps)
-                x = x.reshape(B, C, H, W)
-                return x.mul_(self.weight).add_(self.bias)
-            else:
-                return x.mul(self.weight).add_(self.bias)
+        B, C, H, W = x.shape
+        assert C % self.groups == 0
+        if self.nonlin:
+            n = (x * self.v).sigmoid().reshape(B, self.groups, -1)
+            x = x.reshape(B, self.groups, -1)
+            x = n / x.var(dim=-1, unbiased=False, keepdim=True).add_(self.eps).sqrt_()
+            x = x.reshape(B, C, H, W)
+            return x.mul_(self.weight).add_(self.bias)
+        else:
+            return x.mul(self.weight).add_(self.bias)
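With the jit constructor flag and the standalone @torch.jit.script helpers removed, each module now has a single eager code path, and scripting becomes an opt-in step handled by the factory. A small sketch of the two equivalent ways to get a scripted layer after this change (the 32-channel size is illustrative):

import torch
from timm.models.layers import EvoNormSample2d, create_norm_act

# Script the plain module directly...
scripted = torch.jit.script(EvoNormSample2d(32))

# ...or let create_norm_act do it, as the DenseNet wrappers above do.
also_scripted = create_norm_act('EvoNormSample', 32, jit=True)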
