""" Filter Response Norm in PyTorch
|
|
|
|
Based on `Filter Response Normalization Layer` - https://arxiv.org/abs/1911.09737
|
|
|
|
Hacked together by / Copyright 2021 Ross Wightman
|
|
"""
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from .create_act import create_act_layer
|
|
from .trace_utils import _assert
|
|
|
|
|
|
def inv_instance_rms(x, eps: float = 1e-5):
    """Per-instance inverse RMS over the spatial dims of an NCHW tensor.

    Computes 1 / sqrt(mean(x^2) + eps) per (N, C) slice in float32 for
    stability, casts back to x's dtype, and broadcasts to x's full shape.
    """
    # float() before the reduction avoids precision loss under fp16/bf16 inputs
    mean_sq = x.square().float().mean(dim=(2, 3), keepdim=True)
    inv = torch.rsqrt(mean_sq + eps).to(dtype=x.dtype)
    return inv.expand(x.shape)
|
|
|
|
|
|
class FilterResponseNormTlu2d(nn.Module):
    """Filter Response Normalization with a Thresholded Linear Unit (TLU).

    Normalizes each (N, C) slice of a 4D NCHW input by its spatial inverse
    RMS, applies a learned per-channel affine (weight, bias), then applies
    the TLU ``max(x, tau)`` with a learned per-channel threshold when
    ``apply_act`` is True.
    """

    def __init__(self, num_features, apply_act=True, eps=1e-5, rms=True, **_):
        """
        Args:
            num_features: number of channels C.
            apply_act: if True, create a learnable ``tau`` and apply the TLU.
            eps: stability constant added before the rsqrt.
            rms: stored on the module; NOTE(review): forward does not branch
                on it — normalization is always RMS-based here. Confirm intent.
        """
        super().__init__()
        self.apply_act = apply_act  # apply activation (non-linearity)
        self.rms = rms
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        if apply_act:
            self.tau = nn.Parameter(torch.zeros(num_features))
        else:
            self.tau = None
        self.reset_parameters()

    def reset_parameters(self):
        """Reset weight to ones, bias (and tau, if present) to zeros."""
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)
        if self.tau is not None:
            nn.init.zeros_(self.tau)

    def forward(self, x):
        _assert(x.dim() == 4, 'expected 4D input')
        dtype = x.dtype
        # broadcast shape for the per-channel parameters
        shape = (1, -1, 1, 1)
        x = x * inv_instance_rms(x, self.eps)
        x = x * self.weight.view(shape).to(dtype=dtype) + self.bias.view(shape).to(dtype=dtype)
        if self.tau is None:
            return x
        return torch.maximum(x, self.tau.reshape(shape).to(dtype=dtype))
|
|
|
|
|
|
class FilterResponseNormAct2d(nn.Module):
    """Filter Response Normalization followed by a configurable activation.

    Normalizes each (N, C) slice of a 4D NCHW input by its spatial inverse
    RMS, applies a learned per-channel affine (weight, bias), then runs the
    result through ``act_layer`` (identity when ``apply_act`` is False).
    """

    def __init__(self, num_features, apply_act=True, act_layer=nn.ReLU, inplace=None, rms=True, eps=1e-5, **_):
        """
        Args:
            num_features: number of channels C.
            apply_act: if False, the activation is replaced with Identity.
            act_layer: activation constructor passed to create_act_layer.
            inplace: forwarded to the activation factory.
            rms: stored on the module; NOTE(review): forward does not branch
                on it — normalization is always RMS-based here. Confirm intent.
            eps: stability constant added before the rsqrt.
        """
        super().__init__()
        if act_layer is None or not apply_act:
            self.act = nn.Identity()
        else:
            self.act = create_act_layer(act_layer, inplace=inplace)
        self.rms = rms
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Reset weight to ones and bias to zeros."""
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        _assert(x.dim() == 4, 'expected 4D input')
        dtype = x.dtype
        # broadcast shape for the per-channel parameters
        shape = (1, -1, 1, 1)
        x = x * inv_instance_rms(x, self.eps)
        x = x * self.weight.view(shape).to(dtype=dtype) + self.bias.view(shape).to(dtype=dtype)
        return self.act(x)
|