You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
pytorch-image-models/timm/layers/filter_response_norm.py

69 lines
2.5 KiB

""" Filter Response Norm in PyTorch
Based on `Filter Response Normalization Layer` - https://arxiv.org/abs/1911.09737
Hacked together by / Copyright 2021 Ross Wightman
"""
import torch
import torch.nn as nn
from .create_act import create_act_layer
from .trace_utils import _assert
def inv_instance_rms(x, eps: float = 1e-5):
    """Return ``1 / sqrt(mean(x ** 2) + eps)`` reduced over dims 2 and 3, expanded to ``x.shape``.

    The mean of squares is computed in float32 and the result is cast back to
    ``x.dtype``, so the returned tensor has the same shape and dtype as ``x``.

    Args:
        x: Input tensor. The reduction is over dims 2 and 3, so ``x`` needs at
            least 4 dims (the modules in this file assert a 4D input).
        eps: Constant added to the mean of squares before the inverse square root.

    Returns:
        Tensor shaped like ``x`` holding the inverse RMS statistic, broadcast
        (via ``expand``) over the reduced dims.
    """
    # Upcast *before* squaring: in float16, |x| > ~256 squares past the max finite
    # value (65504) to inf, and rsqrt(inf) == 0 would silently zero the output.
    rms = x.float().square().mean(dim=(2, 3), keepdim=True).add(eps).rsqrt().to(x.dtype)
    return rms.expand(x.shape)
class FilterResponseNormTlu2d(nn.Module):
    """Filter Response Normalization with an optional Thresholded Linear Unit (TLU).

    The input is multiplied by its inverse RMS (``inv_instance_rms``), then a
    per-channel affine ``weight`` / ``bias`` is applied. When ``apply_act`` is
    true, the result is clamped from below by a learned per-channel ``tau``
    (``max(x, tau)``).

    Args:
        num_features: Length of the per-channel parameter vectors (matched
            against dim 1 of the input through a ``(1, -1, 1, 1)`` view).
        apply_act: If true, create ``tau`` and apply the TLU; otherwise ``tau`` is None.
        eps: Constant passed to ``inv_instance_rms``.
        rms: Stored as ``self.rms``; not read anywhere in this class.
        **_: Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, num_features, apply_act=True, eps=1e-5, rms=True, **_):
        super().__init__()
        self.apply_act = apply_act  # whether the TLU non-linearity is used
        self.rms = rms  # NOTE(review): kept for interface compatibility; forward() never reads it
        self.eps = eps
        # Registration order (weight, bias, tau) determines parameter / state_dict order.
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        if apply_act:
            self.tau = nn.Parameter(torch.zeros(num_features))
        else:
            self.tau = None
        self.reset_parameters()

    def reset_parameters(self):
        """Set ``weight`` to ones and ``bias`` (and ``tau``, when present) to zeros."""
        nn.init.ones_(self.weight)
        for zero_param in (self.bias, self.tau):
            if zero_param is not None:
                nn.init.zeros_(zero_param)

    def forward(self, x):
        """Normalize ``x`` (4D required), apply the affine, then the TLU if enabled."""
        _assert(x.dim() == 4, 'expected 4D input')
        dtype = x.dtype
        chan_shape = (1, -1, 1, 1)  # per-channel params broadcast along dim 1
        scale = self.weight.view(chan_shape).to(dtype=dtype)
        shift = self.bias.view(chan_shape).to(dtype=dtype)
        x = x * inv_instance_rms(x, self.eps)
        x = x * scale + shift
        if self.tau is None:
            return x
        return torch.maximum(x, self.tau.reshape(chan_shape).to(dtype=dtype))
class FilterResponseNormAct2d(nn.Module):
    """Filter Response Normalization followed by a configurable activation layer.

    The input is multiplied by its inverse RMS (``inv_instance_rms``), then a
    per-channel affine ``weight`` / ``bias`` is applied, and finally ``self.act``.

    Args:
        num_features: Length of the per-channel parameter vectors (matched
            against dim 1 of the input through a ``(1, -1, 1, 1)`` view).
        apply_act: If true and ``act_layer`` is not None, build the activation
            with ``create_act_layer``; otherwise ``nn.Identity`` is used.
        act_layer: Activation spec forwarded to ``create_act_layer``.
        inplace: Forwarded to ``create_act_layer``.
        rms: Stored as ``self.rms``; not read anywhere in this class.
        eps: Constant passed to ``inv_instance_rms``.
        **_: Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, num_features, apply_act=True, act_layer=nn.ReLU, inplace=None, rms=True, eps=1e-5, **_):
        super().__init__()
        wants_act = apply_act and act_layer is not None
        self.act = create_act_layer(act_layer, inplace=inplace) if wants_act else nn.Identity()
        self.rms = rms  # NOTE(review): kept for interface compatibility; forward() never reads it
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Set the affine ``weight`` to ones and ``bias`` to zeros."""
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        """Normalize ``x`` (4D required), apply the affine, then the activation."""
        _assert(x.dim() == 4, 'expected 4D input')
        dtype = x.dtype
        chan_shape = (1, -1, 1, 1)  # per-channel params broadcast along dim 1
        normed = x * inv_instance_rms(x, self.eps)
        gain = self.weight.view(chan_shape).to(dtype=dtype)
        offset = self.bias.view(chan_shape).to(dtype=dtype)
        return self.act(normed * gain + offset)