@@ -18,7 +18,7 @@ import torch.nn.functional as F
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
 from .registry import register_model
-from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, get_act_layer, get_attn, make_divisible
+from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, get_act_layer, get_attn, make_divisible, get_act_fn


 def _dcfg(url='', **kwargs):
@@ -40,17 +40,17 @@ default_cfgs = {
     'nf_regnet_b4': _dcfg(url='', input_size=(3, 320, 320)),
     'nf_regnet_b5': _dcfg(url='', input_size=(3, 384, 384)),

-    'nf_resnet26d': _dcfg(url='', first_conv='stem.conv1'),
-    'nf_resnet50d': _dcfg(url='', first_conv='stem.conv1'),
-    'nf_resnet101d': _dcfg(url='', first_conv='stem.conv1'),
+    'nf_resnet26': _dcfg(url='', first_conv='stem.conv'),
+    'nf_resnet50': _dcfg(url='', first_conv='stem.conv'),
+    'nf_resnet101': _dcfg(url='', first_conv='stem.conv'),

-    'nf_seresnet26d': _dcfg(url='', first_conv='stem.conv1'),
-    'nf_seresnet50d': _dcfg(url='', first_conv='stem.conv1'),
-    'nf_seresnet101d': _dcfg(url='', first_conv='stem.conv1'),
+    'nf_seresnet26': _dcfg(url='', first_conv='stem.conv'),
+    'nf_seresnet50': _dcfg(url='', first_conv='stem.conv'),
+    'nf_seresnet101': _dcfg(url='', first_conv='stem.conv'),

-    'nf_ecaresnet26d': _dcfg(url='', first_conv='stem.conv1'),
-    'nf_ecaresnet50d': _dcfg(url='', first_conv='stem.conv1'),
-    'nf_ecaresnet101d': _dcfg(url='', first_conv='stem.conv1'),
+    'nf_ecaresnet26': _dcfg(url='', first_conv='stem.conv'),
+    'nf_ecaresnet50': _dcfg(url='', first_conv='stem.conv'),
+    'nf_ecaresnet101': _dcfg(url='', first_conv='stem.conv'),
 }
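
The rename also retargets `first_conv` from the deep stem's `stem.conv1` to the single 7x7 `stem.conv`. A minimal sanity-check sketch, assuming a local timm build with these variants registered (`timm.create_model` and `default_cfg` are standard timm API):

```python
import timm  # assumes these variants are registered in the local build

model = timm.create_model('nf_resnet50', pretrained=False)
assert model.default_cfg['first_conv'] == 'stem.conv'  # single 7x7 conv, no conv1/2/3 deep stem
# with in_chans != 3, timm uses first_conv to locate and re-init the input conv
gray = timm.create_model('nf_resnet50', pretrained=False, in_chans=1)
```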
@@ -59,6 +59,7 @@ class NfCfg:
     depths: Tuple[int, int, int, int]
     channels: Tuple[int, int, int, int]
     alpha: float = 0.2
+    gamma_in_act: bool = False
     stem_type: str = '3x3'
     stem_chs: Optional[int] = None
     group_size: Optional[int] = 8
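
For reference, the new flag defaults to `False`, so existing configs behave exactly as before; a hypothetical config opting in (all other fields as used in `model_cfgs` below) would look like:

```python
# Hypothetical example: same family of config as the ResNet defs below,
# but applying gamma in the activation instead of in ScaledStdConv2d.
cfg = NfCfg(
    depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
    stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25,
    efficient=False, group_size=None, act_layer='relu', attn_layer=None,
    gamma_in_act=True)
```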
@@ -84,68 +85,65 @@ model_cfgs = dict(
     nf_regnet_b5=NfCfg(depths=(3, 7, 14, 14), channels=(80, 168, 336, 704), num_features=2048),

     # ResNet (preact, D style deep stem/avg down) defs
-    nf_resnet26d=NfCfg(
+    nf_resnet26=NfCfg(
         depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer=None),
-    nf_resnet50d=NfCfg(
+    nf_resnet50=NfCfg(
         depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer=None),
-    nf_resnet101d=NfCfg(
+    nf_resnet101=NfCfg(
         depths=(3, 4, 23, 3), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer=None),
-    nf_seresnet26d=NfCfg(
+    nf_seresnet26=NfCfg(
         depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
-    nf_seresnet50d=NfCfg(
+    nf_seresnet50=NfCfg(
         depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
-    nf_seresnet101d=NfCfg(
+    nf_seresnet101=NfCfg(
         depths=(3, 4, 23, 3), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
-    nf_ecaresnet26d=NfCfg(
+    nf_ecaresnet26=NfCfg(
         depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer='eca', attn_kwargs=dict()),
-    nf_ecaresnet50d=NfCfg(
+    nf_ecaresnet50=NfCfg(
         depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer='eca', attn_kwargs=dict()),
-    nf_ecaresnet101d=NfCfg(
+    nf_ecaresnet101=NfCfg(
         depths=(3, 4, 23, 3), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer='eca', attn_kwargs=dict()),
 )
-# class NormFreeSiLU(nn.Module):
-#     _K = 1. / 0.5595
-#
-#     def __init__(self, inplace=False):
-#         super().__init__()
-#         self.inplace = inplace
-#
-#     def forward(self, x):
-#         return F.silu(x, inplace=self.inplace) * self._K
-#
-#
-# class NormFreeReLU(nn.Module):
-#     _K = (0.5 * (1. - 1. / math.pi)) ** -0.5
-#
-#     def __init__(self, inplace=False):
-#         super().__init__()
-#         self.inplace = inplace
-#
-#     def forward(self, x):
-#         return F.relu(x, inplace=self.inplace) * self._K
+class GammaAct(nn.Module):
+    def __init__(self, act_type='relu', gamma: float = 1.0, inplace=False):
+        super().__init__()
+        self.act_fn = get_act_fn(act_type)
+        self.gamma = gamma
+        self.inplace = inplace
+
+    def forward(self, x):
+        return self.gamma * self.act_fn(x, inplace=self.inplace)
+
+
+def act_with_gamma(act_type, gamma: float = 1.):
+    def _create(inplace=False):
+        return GammaAct(act_type, gamma=gamma, inplace=inplace)
+    return _create
+
+
 class DownsampleAvg(nn.Module):
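
A quick standalone check of what `GammaAct` accomplishes, in plain PyTorch (not part of the diff): for `x ~ N(0, 1)`, `relu(x)` has std `(0.5 * (1 - 1/pi)) ** 0.5 ≈ 0.584`, so scaling by the matching gamma restores roughly unit variance.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1_000_000)
gamma = (0.5 * (1. - 1. / math.pi)) ** -0.5  # ~1.7139, matches _nonlin_gamma['relu'] below
print(F.relu(x).std().item())            # ~0.584 = 1 / gamma
print((gamma * F.relu(x)).std().item())  # ~1.0, variance restored
```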
@@ -178,10 +176,9 @@ class NormalizationFreeBlock(nn.Module):
         out_chs = out_chs or in_chs
         # EfficientNet-like models scale bottleneck from in_chs, otherwise scale from out_chs like ResNet
         mid_chs = make_divisible(in_chs * bottle_ratio if efficient else out_chs * bottle_ratio, ch_div)
-        groups = 1
-        if group_size is not None:
-            # NOTE: not correcting the mid_chs % group_size, fix model def if broken. I want % ch_div == 0 to stand.
-            groups = mid_chs // group_size
+        groups = 1 if group_size is None else mid_chs // group_size
+        if group_size and group_size % ch_div == 0:
+            mid_chs = group_size * groups  # correct mid_chs if group_size divisible by ch_div, otherwise error
         self.alpha = alpha
         self.beta = beta
         self.attn_gain = attn_gain
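
The new logic snaps `mid_chs` down to a multiple of `group_size` rather than leaving a broken grouped conv behind. A standalone mirror of that arithmetic (hypothetical helper, illustrative numbers):

```python
def grouping(mid_chs, group_size, ch_div=8):
    # standalone mirror of the block logic above (sketch, not the actual class)
    groups = 1 if group_size is None else mid_chs // group_size
    if group_size and group_size % ch_div == 0:
        mid_chs = group_size * groups  # snap mid_chs down to a multiple of group_size
    return mid_chs, groups

print(grouping(136, 16))    # (128, 8): 136 is ch_div-aligned but not group_size-aligned
print(grouping(128, None))  # (128, 1): no grouped conv requested
```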
@@ -229,10 +226,11 @@ class NormalizationFreeBlock(nn.Module):

 def create_stem(in_chs, out_chs, stem_type='', conv_layer=None):
+    stem_stride = 2
     stem = OrderedDict()
-    assert stem_type in ('', 'deep', '3x3', '7x7')
+    assert stem_type in ('', 'deep', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool')
     if 'deep' in stem_type:
-        # 3 deep 3x3 conv stack as in ResNet V1D models
+        # 3 deep 3x3 conv stack as in ResNet V1D models. NOTE: doesn't work as well here
         mid_chs = out_chs // 2
         stem['conv1'] = conv_layer(in_chs, mid_chs, kernel_size=3, stride=2)
         stem['conv2'] = conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1)
@@ -244,12 +242,16 @@ def create_stem(in_chs, out_chs, stem_type='', conv_layer=None):
         # 7x7 stem conv as in ResNet
         stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2)
-    return nn.Sequential(stem)
+
+    if 'pool' in stem_type:
+        stem['pool'] = nn.MaxPool2d(3, stride=2, padding=1)
+        stem_stride = 4
+
+    return nn.Sequential(stem), stem_stride


 _nonlin_gamma = dict(
-    silu=.5595,
-    relu=(0.5 * (1. - 1. / math.pi)) ** 0.5,
+    silu=1. / .5595,
+    relu=(0.5 * (1. - 1. / math.pi)) ** -0.5,
     identity=1.0
 )
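
The constant fix is easy to verify empirically: gamma should be `1 / std(act(x))` for `x ~ N(0, 1)`, and the old values were the reciprocal (silu) and inverse-exponent (relu) forms of that. A Monte Carlo check (standalone sketch):

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(10_000_000)
print(1. / F.silu(x).std().item())  # ~1.79, i.e. 1. / .5595
print(1. / F.relu(x).std().item(), (0.5 * (1. - 1. / math.pi)) ** -0.5)  # both ~1.7139
```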
@@ -264,9 +266,12 @@ class NormalizerFreeNet(nn.Module):
     the (preact) ResNet models described earlier in the paper.

     There are a few differences:
-        * channels are rounded to be divisible by 8 by default (keep TC happy), this changes param counts
+        * channels are rounded to be divisible by 8 by default (keep tensor core kernels happy),
+            this changes channel dim and param counts slightly from the paper models
         * activation correcting gamma constants are moved into the ScaledStdConv as it has less performance
             impact in PyTorch when done with the weight scaling there. This likely wasn't a concern in the JAX impl.
+        * a config option `gamma_in_act` can be enabled to not apply gamma in StdConv as described above, but
+            apply it in each activation. This is slightly slower, and yields slightly different results.
         * skipinit is disabled by default, it seems to have a rather drastic impact on GPU memory use and throughput
             for what it is/does. Approx 8-10% throughput loss.
     """
@@ -275,29 +280,33 @@ class NormalizerFreeNet(nn.Module):
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
-        act_layer = get_act_layer(cfg.act_layer)
         assert cfg.act_layer in _nonlin_gamma, f"Please add non-linearity constants for activation ({cfg.act_layer})."
-        conv_layer = partial(ScaledStdConv2d, bias=True, gain=True, gamma=_nonlin_gamma[cfg.act_layer])
+        if cfg.gamma_in_act:
+            act_layer = act_with_gamma(cfg.act_layer, gamma=_nonlin_gamma[cfg.act_layer])
+            conv_layer = partial(ScaledStdConv2d, bias=True, gain=True)
+        else:
+            act_layer = get_act_layer(cfg.act_layer)
+            conv_layer = partial(ScaledStdConv2d, bias=True, gain=True, gamma=_nonlin_gamma[cfg.act_layer])
         attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None

-        self.feature_info = []  # FIXME fill out feature info
         stem_chs = cfg.stem_chs or cfg.channels[0]
         stem_chs = make_divisible(stem_chs * cfg.width_factor, cfg.ch_div)
-        self.stem = create_stem(in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer)
+        self.stem, stem_stride = create_stem(in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer)
+        prev_chs = stem_chs
+        self.feature_info = []  # NOTE: there will be no stride == 2 feature if stem_stride == 4

         dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
-        net_stride = 2
-        prev_chs = stem_chs
+        net_stride = stem_stride
         dilation = 1
         expected_var = 1.0
         stages = []
         for stage_idx, stage_depth in enumerate(cfg.depths):
-            if net_stride >= output_stride:
-                dilation *= 2
+            stride = 1 if stage_idx == 0 and stem_stride > 2 else 2
+            self.feature_info += [dict(
+                num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}.0.act1' if stride == 2 else '')]
+            if net_stride >= output_stride and stride > 1:
+                dilation *= stride
                 stride = 1
-            else:
-                stride = 2
             net_stride *= stride
             first_dilation = 1 if dilation in (1, 2) else 2
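
To make the `stem_stride` interaction concrete, here is a standalone mirror of the stride schedule above (hypothetical helper; ignores the dilation bookkeeping):

```python
def stride_schedule(stem_stride, num_stages=4, output_stride=32):
    # hedged standalone mirror of the loop above
    net_stride, strides = stem_stride, []
    for stage_idx in range(num_stages):
        stride = 1 if stage_idx == 0 and stem_stride > 2 else 2
        if net_stride >= output_stride and stride > 1:
            stride = 1
        net_stride *= stride
        strides.append(stride)
    return strides, net_stride

print(stride_schedule(4))  # ([1, 2, 2, 2], 32): '7x7_pool' stem already covers stride 4
print(stride_schedule(2))  # ([2, 2, 2, 2], 32): '3x3'/'deep'/'7x7' stems behave as before
```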
@@ -338,7 +347,10 @@ class NormalizerFreeNet(nn.Module):
         else:
             self.num_features = prev_chs
             self.final_conv = nn.Identity()
+        # FIXME not 100% clear on gamma subtleties final conv/final act in case where it's in stdconv
         self.final_act = act_layer()
+        self.feature_info += [dict(num_chs=self.num_features, reduction=net_stride, module='final_act')]

         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

         for n, m in self.named_modules():
@@ -373,11 +385,14 @@ class NormalizerFreeNet(nn.Module):

 def _create_normfreenet(variant, pretrained=False, **kwargs):
+    model_cfg = model_cfgs[variant]
     feature_cfg = dict(flatten_sequential=True)
+    feature_cfg['feature_cls'] = 'hook'  # pre-act models need hooks to grab feat from act1 in bottleneck blocks
+    if 'pool' in model_cfg.stem_type:
+        feature_cfg['out_indices'] = (1, 2, 3, 4)  # no stride 2, 0 level feat for stride 4 maxpool stems in ResNet
     return build_model_with_cfg(
-        NormalizerFreeNet, variant, pretrained, model_cfg=model_cfgs[variant], default_cfg=default_cfgs[variant],
+        NormalizerFreeNet, variant, pretrained, model_cfg=model_cfg, default_cfg=default_cfgs[variant],
         feature_cfg=feature_cfg, **kwargs)
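
A hedged usage sketch of the feature extraction this configures, assuming these variants are registered in the local build (`feature_info.reduction()` is the standard timm accessor):

```python
import torch
import timm  # assumes a local build with these variants registered

m = timm.create_model('nf_resnet50', features_only=True, pretrained=False)
print(m.feature_info.reduction())  # expected [4, 8, 16, 32]; no stride-2 level with the pool stem
feats = m(torch.randn(2, 3, 224, 224))
print([f.shape for f in feats])
```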
@@ -412,30 +427,30 @@ def nf_regnet_b5(pretrained=False, **kwargs):

 @register_model
-def nf_resnet26d(pretrained=False, **kwargs):
-    return _create_normfreenet('nf_resnet26d', pretrained=pretrained, **kwargs)
+def nf_resnet26(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_resnet26', pretrained=pretrained, **kwargs)


 @register_model
-def nf_resnet50d(pretrained=False, **kwargs):
-    return _create_normfreenet('nf_resnet50d', pretrained=pretrained, **kwargs)
+def nf_resnet50(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_resnet50', pretrained=pretrained, **kwargs)


 @register_model
-def nf_seresnet26d(pretrained=False, **kwargs):
-    return _create_normfreenet('nf_seresnet26d', pretrained=pretrained, **kwargs)
+def nf_seresnet26(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_seresnet26', pretrained=pretrained, **kwargs)


 @register_model
-def nf_seresnet50d(pretrained=False, **kwargs):
-    return _create_normfreenet('nf_seresnet50d', pretrained=pretrained, **kwargs)
+def nf_seresnet50(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_seresnet50', pretrained=pretrained, **kwargs)


 @register_model
-def nf_ecaresnet26d(pretrained=False, **kwargs):
-    return _create_normfreenet('nf_ecaresnet26d', pretrained=pretrained, **kwargs)
+def nf_ecaresnet26(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_ecaresnet26', pretrained=pretrained, **kwargs)


 @register_model
-def nf_ecaresnet50d(pretrained=False, **kwargs):
-    return _create_normfreenet('nf_ecaresnet50d', pretrained=pretrained, **kwargs)
+def nf_ecaresnet50(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_ecaresnet50', pretrained=pretrained, **kwargs)