Add norm_act factory method, move JIT of norm layers to factory

pull/155/head
Ross Wightman 4 years ago
parent 14edacdf9a
commit 780860d140

@ -4,6 +4,7 @@ fixed kwargs passthrough and addition of dynamic global avg/max pool.
"""
import re
from collections import OrderedDict
from functools import partial
import torch
import torch.nn as nn
@ -13,7 +14,7 @@ from torch.jit.annotations import List
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d, BatchNormAct2d, EvoNormBatch2d, EvoNormSample2d
from .layers import SelectAdaptivePool2d, BatchNormAct2d, create_norm_act
from .registry import register_model
__all__ = ['DenseNet']
@ -327,9 +328,11 @@ def densenet121d_evob(pretrained=False, **kwargs):
r"""Densenet-121 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
"""
def norm_act_fn(num_features, **kwargs):
return create_norm_act('EvoNormBatch', num_features, jit=True, **kwargs)
model = _densenet(
'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
norm_act_layer=EvoNormBatch2d, pretrained=pretrained, **kwargs)
norm_act_layer=norm_act_fn, pretrained=pretrained, **kwargs)
return model
@ -338,9 +341,11 @@ def densenet121d_evos(pretrained=False, **kwargs):
r"""Densenet-121 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
"""
def norm_act_fn(num_features, **kwargs):
return create_norm_act('EvoNormSample', num_features, jit=True, **kwargs)
model = _densenet(
'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
norm_act_layer=EvoNormSample2d, pretrained=pretrained, **kwargs)
norm_act_layer=norm_act_fn, pretrained=pretrained, **kwargs)
return model
@ -349,10 +354,11 @@ def densenet121d_iabn(pretrained=False, **kwargs):
r"""Densenet-121 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
"""
from inplace_abn import InPlaceABN
def norm_act_fn(num_features, **kwargs):
return create_norm_act('iabn', num_features, **kwargs)
model = _densenet(
'densenet121tn', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
norm_act_layer=InPlaceABN, pretrained=pretrained, **kwargs)
norm_act_layer=norm_act_fn, pretrained=pretrained, **kwargs)
return model

@ -20,4 +20,5 @@ from .anti_aliasing import AntiAliasDownsampleLayer
from .space_to_depth import SpaceToDepthModule
from .blur_pool import BlurPool2d
from .norm_act import BatchNormAct2d
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
from .create_norm_act import create_norm_act

@ -0,0 +1,37 @@
import torch
import torch.nn as nn
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
from .norm_act import BatchNormAct2d
try:
from inplace_abn import InPlaceABN
has_iabn = True
except ImportError:
has_iabn = False
def create_norm_act(layer_type, num_features, jit=False, **kwargs):
layer_parts = layer_type.split('_')
assert len(layer_parts) in (1, 2)
layer_class = layer_parts[0].lower()
#activation_class = layer_parts[1].lower() if len(layer_parts) > 1 else '' # FIXME support string act selection
if layer_class == "batchnormact":
layer = BatchNormAct2d(num_features, **kwargs) # defaults to RELU of no kwargs override
elif layer_class == "batchnormrelu":
assert 'act_layer' not in kwargs
layer = BatchNormAct2d(num_features, act_layer=nn.ReLU, **kwargs)
elif layer_class == "evonormbatch":
layer = EvoNormBatch2d(num_features, **kwargs)
elif layer_class == "evonormsample":
layer = EvoNormSample2d(num_features, **kwargs)
elif layer_class == "iabn" or layer_class == "inplaceabn":
if not has_iabn:
raise ImportError(
"Pplease install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11'")
layer = InPlaceABN(num_features, **kwargs)
else:
assert False, "Invalid norm_act layer (%s)" % layer_class
if jit:
layer = torch.jit.script(layer)
return layer

@ -13,35 +13,12 @@ import torch
import torch.nn as nn
@torch.jit.script
def evo_batch_jit(
x: torch.Tensor, v: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, running_var: torch.Tensor,
momentum: float, training: bool, nonlin: bool, eps: float):
x_type = x.dtype
running_var = running_var.detach() # FIXME why is this needed, it's a buffer?
if training:
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True) # FIXME biased, unbiased?
running_var.copy_(momentum * var + (1 - momentum) * running_var)
else:
var = running_var.clone()
if nonlin:
# FIXME biased, unbiased?
d = (x * v.to(x_type)) + x.var(dim=(2, 3), unbiased=False, keepdim=True).add_(eps).sqrt_().to(dtype=x_type)
d = d.max(var.add(eps).sqrt_().to(dtype=x_type))
x = x / d
return x.mul_(weight).add_(bias)
else:
return x.mul(weight).add_(bias)
class EvoNormBatch2d(nn.Module):
def __init__(self, num_features, momentum=0.1, nonlin=True, eps=1e-5, jit=True):
def __init__(self, num_features, momentum=0.1, nonlin=True, eps=1e-5):
super(EvoNormBatch2d, self).__init__()
self.momentum = momentum
self.nonlin = nonlin
self.eps = eps
self.jit = jit
param_shape = (1, num_features, 1, 1)
self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True)
self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True)
@ -58,50 +35,29 @@ class EvoNormBatch2d(nn.Module):
def forward(self, x):
assert x.dim() == 4, 'expected 4D input'
if self.jit:
return evo_batch_jit(
x, self.v, self.weight, self.bias, self.running_var, self.momentum,
self.training, self.nonlin, self.eps)
x_type = x.dtype
if self.training:
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
self.running_var.copy_(self.momentum * var.detach() + (1 - self.momentum) * self.running_var)
else:
x_type = x.dtype
if self.training:
var = x.var(dim=(0, 2, 3), keepdim=True)
self.running_var.copy_(self.momentum * var + (1 - self.momentum) * self.running_var)
else:
var = self.running_var.clone()
if self.nonlin:
v = self.v.to(dtype=x_type)
d = (x * v) + x.var(dim=(2, 3), keepdim=True).add_(self.eps).sqrt_().to(dtype=x_type)
d = d.max(var.add(self.eps).sqrt_().to(dtype=x_type))
x = x / d
return x.mul_(self.weight).add_(self.bias)
else:
return x.mul(self.weight).add_(self.bias)
var = self.running_var.clone()
@torch.jit.script
def evo_sample_jit(
x: torch.Tensor, v: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
groups: int, nonlin: bool, eps: float):
B, C, H, W = x.shape
assert C % groups == 0
if nonlin:
n = (x * v).sigmoid_().reshape(B, groups, -1)
x = x.reshape(B, groups, -1)
x = n / x.var(dim=-1, unbiased=False, keepdim=True).add_(eps).sqrt_()
x = x.reshape(B, C, H, W)
return x.mul_(weight).add_(bias)
if self.nonlin:
v = self.v.to(dtype=x_type)
d = (x * v) + x.var(dim=(2, 3), unbiased=False, keepdim=True).add_(self.eps).sqrt_().to(dtype=x_type)
d = d.max(var.add_(self.eps).sqrt_().to(dtype=x_type))
x = x / d
return x.mul_(self.weight).add_(self.bias)
else:
return x.mul(self.weight).add_(self.bias)
class EvoNormSample2d(nn.Module):
def __init__(self, num_features, nonlin=True, groups=8, eps=1e-5, jit=True):
def __init__(self, num_features, nonlin=True, groups=8, eps=1e-5):
super(EvoNormSample2d, self).__init__()
self.nonlin = nonlin
self.groups = groups
self.eps = eps
self.jit = jit
param_shape = (1, num_features, 1, 1)
self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True)
self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True)
@ -117,18 +73,13 @@ class EvoNormSample2d(nn.Module):
def forward(self, x):
assert x.dim() == 4, 'expected 4D input'
if self.jit:
return evo_sample_jit(
x, self.v, self.weight, self.bias, self.groups, self.nonlin, self.eps)
B, C, H, W = x.shape
assert C % self.groups == 0
if self.nonlin:
n = (x * self.v).sigmoid().reshape(B, self.groups, -1)
x = x.reshape(B, self.groups, -1)
x = n / x.var(dim=-1, unbiased=False, keepdim=True).add_(self.eps).sqrt_()
x = x.reshape(B, C, H, W)
return x.mul_(self.weight).add_(self.bias)
else:
B, C, H, W = x.shape
assert C % self.groups == 0
if self.nonlin:
n = (x * self.v).sigmoid().reshape(B, self.groups, -1)
x = x.reshape(B, self.groups, -1)
x = n / (x.std(dim=-1, unbiased=False, keepdim=True) + self.eps)
x = x.reshape(B, C, H, W)
return x.mul_(self.weight).add_(self.bias)
else:
return x.mul(self.weight).add_(self.bias)
return x.mul(self.weight).add_(self.bias)

Loading…
Cancel
Save