@ -15,45 +15,76 @@ Hacked together by / Copyright 2020 Ross Wightman
"""
import numpy as np
import torch . nn as nn
from dataclasses import dataclass
from functools import partial
from typing import Optional , Union , Callable
from timm . data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD
from . helpers import build_model_with_cfg
from . layers import ClassifierHead , AvgPool2dSame , Conv Bn Act, SEModule , DropPath
from . helpers import build_model_with_cfg , named_apply
from . layers import ClassifierHead , AvgPool2dSame , Conv Norm Act, SEModule , DropPath , get_act_layer , GroupNormAct
from . registry import register_model
def _mcfg ( * * kwargs ) :
cfg = dict ( se_ratio = 0. , bottle_ratio = 1. , stem_width = 32 )
cfg . update ( * * kwargs )
return cfg
@dataclass
class RegNetCfg :
depth : int = 21
w0 : int = 80
wa : float = 42.63
wm : float = 2.66
group_size : int = 24
bottle_ratio : float = 1.
se_ratio : float = 0.
stem_width : int = 32
downsample : Optional [ str ] = ' conv1x1 '
linear_out : bool = False
act_layer : Union [ str , Callable ] = ' relu '
norm_layer : Union [ str , Callable ] = ' batchnorm '
# Model FLOPS = three trailing digits * 10^8
model_cfgs = dict (
regnetx_002 = _mcfg ( w0 = 24 , wa = 36.44 , wm = 2.49 , group_w = 8 , depth = 13 ) ,
regnetx_004 = _mcfg ( w0 = 24 , wa = 24.48 , wm = 2.54 , group_w = 16 , depth = 22 ) ,
regnetx_006 = _mcfg ( w0 = 48 , wa = 36.97 , wm = 2.24 , group_w = 24 , depth = 16 ) ,
regnetx_008 = _mcfg ( w0 = 56 , wa = 35.73 , wm = 2.28 , group_w = 16 , depth = 16 ) ,
regnetx_016 = _mcfg ( w0 = 80 , wa = 34.01 , wm = 2.25 , group_w = 24 , depth = 18 ) ,
regnetx_032 = _mcfg ( w0 = 88 , wa = 26.31 , wm = 2.25 , group_w = 48 , depth = 25 ) ,
regnetx_040 = _mcfg ( w0 = 96 , wa = 38.65 , wm = 2.43 , group_w = 40 , depth = 23 ) ,
regnetx_064 = _mcfg ( w0 = 184 , wa = 60.83 , wm = 2.07 , group_w = 56 , depth = 17 ) ,
regnetx_080 = _mcfg ( w0 = 80 , wa = 49.56 , wm = 2.88 , group_w = 120 , depth = 23 ) ,
regnetx_120 = _mcfg ( w0 = 168 , wa = 73.36 , wm = 2.37 , group_w = 112 , depth = 19 ) ,
regnetx_160 = _mcfg ( w0 = 216 , wa = 55.59 , wm = 2.1 , group_w = 128 , depth = 22 ) ,
regnetx_320 = _mcfg ( w0 = 320 , wa = 69.86 , wm = 2.0 , group_w = 168 , depth = 23 ) ,
regnety_002 = _mcfg ( w0 = 24 , wa = 36.44 , wm = 2.49 , group_w = 8 , depth = 13 , se_ratio = 0.25 ) ,
regnety_004 = _mcfg ( w0 = 48 , wa = 27.89 , wm = 2.09 , group_w = 8 , depth = 16 , se_ratio = 0.25 ) ,
regnety_006 = _mcfg ( w0 = 48 , wa = 32.54 , wm = 2.32 , group_w = 16 , depth = 15 , se_ratio = 0.25 ) ,
regnety_008 = _mcfg ( w0 = 56 , wa = 38.84 , wm = 2.4 , group_w = 16 , depth = 14 , se_ratio = 0.25 ) ,
regnety_016 = _mcfg ( w0 = 48 , wa = 20.71 , wm = 2.65 , group_w = 24 , depth = 27 , se_ratio = 0.25 ) ,
regnety_032 = _mcfg ( w0 = 80 , wa = 42.63 , wm = 2.66 , group_w = 24 , depth = 21 , se_ratio = 0.25 ) ,
regnety_040 = _mcfg ( w0 = 96 , wa = 31.41 , wm = 2.24 , group_w = 64 , depth = 22 , se_ratio = 0.25 ) ,
regnety_064 = _mcfg ( w0 = 112 , wa = 33.22 , wm = 2.27 , group_w = 72 , depth = 25 , se_ratio = 0.25 ) ,
regnety_080 = _mcfg ( w0 = 192 , wa = 76.82 , wm = 2.19 , group_w = 56 , depth = 17 , se_ratio = 0.25 ) ,
regnety_120 = _mcfg ( w0 = 168 , wa = 73.36 , wm = 2.37 , group_w = 112 , depth = 19 , se_ratio = 0.25 ) ,
regnety_160 = _mcfg ( w0 = 200 , wa = 106.23 , wm = 2.48 , group_w = 112 , depth = 18 , se_ratio = 0.25 ) ,
regnety_320 = _mcfg ( w0 = 232 , wa = 115.89 , wm = 2.53 , group_w = 232 , depth = 20 , se_ratio = 0.25 ) ,
# RegNet-X
regnetx_002 = RegNetCfg ( w0 = 24 , wa = 36.44 , wm = 2.49 , group_size = 8 , depth = 13 ) ,
regnetx_004 = RegNetCfg ( w0 = 24 , wa = 24.48 , wm = 2.54 , group_size = 16 , depth = 22 ) ,
regnetx_006 = RegNetCfg ( w0 = 48 , wa = 36.97 , wm = 2.24 , group_size = 24 , depth = 16 ) ,
regnetx_008 = RegNetCfg ( w0 = 56 , wa = 35.73 , wm = 2.28 , group_size = 16 , depth = 16 ) ,
regnetx_016 = RegNetCfg ( w0 = 80 , wa = 34.01 , wm = 2.25 , group_size = 24 , depth = 18 ) ,
regnetx_032 = RegNetCfg ( w0 = 88 , wa = 26.31 , wm = 2.25 , group_size = 48 , depth = 25 ) ,
regnetx_040 = RegNetCfg ( w0 = 96 , wa = 38.65 , wm = 2.43 , group_size = 40 , depth = 23 ) ,
regnetx_064 = RegNetCfg ( w0 = 184 , wa = 60.83 , wm = 2.07 , group_size = 56 , depth = 17 ) ,
regnetx_080 = RegNetCfg ( w0 = 80 , wa = 49.56 , wm = 2.88 , group_size = 120 , depth = 23 ) ,
regnetx_120 = RegNetCfg ( w0 = 168 , wa = 73.36 , wm = 2.37 , group_size = 112 , depth = 19 ) ,
regnetx_160 = RegNetCfg ( w0 = 216 , wa = 55.59 , wm = 2.1 , group_size = 128 , depth = 22 ) ,
regnetx_320 = RegNetCfg ( w0 = 320 , wa = 69.86 , wm = 2.0 , group_size = 168 , depth = 23 ) ,
# RegNet-Y
regnety_002 = RegNetCfg ( w0 = 24 , wa = 36.44 , wm = 2.49 , group_size = 8 , depth = 13 , se_ratio = 0.25 ) ,
regnety_004 = RegNetCfg ( w0 = 48 , wa = 27.89 , wm = 2.09 , group_size = 8 , depth = 16 , se_ratio = 0.25 ) ,
regnety_006 = RegNetCfg ( w0 = 48 , wa = 32.54 , wm = 2.32 , group_size = 16 , depth = 15 , se_ratio = 0.25 ) ,
regnety_008 = RegNetCfg ( w0 = 56 , wa = 38.84 , wm = 2.4 , group_size = 16 , depth = 14 , se_ratio = 0.25 ) ,
regnety_016 = RegNetCfg ( w0 = 48 , wa = 20.71 , wm = 2.65 , group_size = 24 , depth = 27 , se_ratio = 0.25 ) ,
regnety_032 = RegNetCfg ( w0 = 80 , wa = 42.63 , wm = 2.66 , group_size = 24 , depth = 21 , se_ratio = 0.25 ) ,
regnety_040 = RegNetCfg ( w0 = 96 , wa = 31.41 , wm = 2.24 , group_size = 64 , depth = 22 , se_ratio = 0.25 ) ,
regnety_064 = RegNetCfg ( w0 = 112 , wa = 33.22 , wm = 2.27 , group_size = 72 , depth = 25 , se_ratio = 0.25 ) ,
regnety_080 = RegNetCfg ( w0 = 192 , wa = 76.82 , wm = 2.19 , group_size = 56 , depth = 17 , se_ratio = 0.25 ) ,
regnety_120 = RegNetCfg ( w0 = 168 , wa = 73.36 , wm = 2.37 , group_size = 112 , depth = 19 , se_ratio = 0.25 ) ,
regnety_160 = RegNetCfg ( w0 = 200 , wa = 106.23 , wm = 2.48 , group_size = 112 , depth = 18 , se_ratio = 0.25 ) ,
regnety_320 = RegNetCfg ( w0 = 232 , wa = 115.89 , wm = 2.53 , group_size = 232 , depth = 20 , se_ratio = 0.25 ) ,
# Experimental
regnety_040s_gn = RegNetCfg (
w0 = 96 , wa = 31.41 , wm = 2.24 , group_size = 64 , depth = 22 , se_ratio = 0.25 ,
act_layer = ' silu ' , norm_layer = partial ( GroupNormAct , group_size = 16 ) ) ,
# RegNet-Z (unverified)
regnetz_005 = RegNetCfg (
depth = 21 , w0 = 16 , wa = 10.7 , wm = 2.51 , group_size = 4 , bottle_ratio = 4.0 , se_ratio = 0.25 ,
downsample = None , linear_out = True , act_layer = ' silu ' ,
) ,
regnetz_040 = RegNetCfg (
depth = 28 , w0 = 48 , wa = 14.5 , wm = 2.226 , group_size = 8 , bottle_ratio = 4.0 , se_ratio = 0.25 ,
downsample = None , linear_out = True , act_layer = ' silu ' ,
) ,
)
@ -80,6 +111,7 @@ default_cfgs = dict(
regnetx_120 = _cfg ( url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_120-65d5521e.pth ' ) ,
regnetx_160 = _cfg ( url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_160-c98c4112.pth ' ) ,
regnetx_320 = _cfg ( url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_320-8ea38b93.pth ' ) ,
regnety_002 = _cfg ( url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_002-e68ca334.pth ' ) ,
regnety_004 = _cfg ( url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_004-0db870e6.pth ' ) ,
regnety_006 = _cfg ( url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_006-c67e57ec.pth ' ) ,
@ -96,6 +128,11 @@ default_cfgs = dict(
url = ' https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth ' , # from Facebook DeiT GitHub repository
crop_pct = 1.0 , test_input_size = ( 3 , 288 , 288 ) ) ,
regnety_320 = _cfg ( url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth ' ) ,
regnety_040s_gn = _cfg ( url = ' ' ) ,
regnetz_005 = _cfg ( url = ' ' ) ,
regnetz_040 = _cfg ( url = ' ' , input_size = ( 3 , 256 , 256 ) , pool_size = ( 8 , 8 ) ) ,
)
@ -125,6 +162,40 @@ def generate_regnet(width_slope, width_initial, width_mult, depth, q=8):
return widths , num_stages , max_stage , widths_cont
def downsample_conv ( in_chs , out_chs , kernel_size = 1 , stride = 1 , dilation = 1 , norm_layer = None ) :
norm_layer = norm_layer or nn . BatchNorm2d
kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size
dilation = dilation if kernel_size > 1 else 1
return ConvNormAct (
in_chs , out_chs , kernel_size , stride = stride , dilation = dilation , norm_layer = norm_layer , apply_act = False )
def downsample_avg ( in_chs , out_chs , kernel_size = 1 , stride = 1 , dilation = 1 , norm_layer = None ) :
""" AvgPool Downsampling as in ' D ' ResNet variants. This is not in RegNet space but I might experiment. """
norm_layer = norm_layer or nn . BatchNorm2d
avg_stride = stride if dilation == 1 else 1
pool = nn . Identity ( )
if stride > 1 or dilation > 1 :
avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn . AvgPool2d
pool = avg_pool_fn ( 2 , avg_stride , ceil_mode = True , count_include_pad = False )
return nn . Sequential ( * [
pool , ConvNormAct ( in_chs , out_chs , 1 , stride = 1 , norm_layer = norm_layer , apply_act = False ) ] )
def create_shortcut ( downsample_type , in_chs , out_chs , kernel_size , stride , dilation = ( 1 , 1 ) , norm_layer = None ) :
assert downsample_type in ( ' avg ' , ' conv1x1 ' , ' ' , None )
if in_chs != out_chs or stride != 1 or dilation [ 0 ] != dilation [ 1 ] :
if not downsample_type :
return None # no shortcut, no downsample
elif downsample_type == ' avg ' :
return downsample_avg ( in_chs , out_chs , stride = stride , dilation = dilation [ 0 ] , norm_layer = norm_layer )
else :
return downsample_conv (
in_chs , out_chs , kernel_size = kernel_size , stride = stride , dilation = dilation [ 0 ] , norm_layer = norm_layer )
else :
return nn . Identity ( ) # identity shortcut (no downsample)
class Bottleneck ( nn . Module ) :
""" RegNet Bottleneck
@ -132,97 +203,70 @@ class Bottleneck(nn.Module):
after conv3 to after conv2 . Otherwise , it ' s just redefining the arguments for groups/bottleneck channels.
"""
def __init__ ( self , in_chs , out_chs , stride = 1 , dilation = 1 , bottleneck_ratio = 1 , group_width = 1 , se_ratio = 0.25 ,
downsample = Non e, act_layer = nn . ReLU , norm_layer = nn . BatchNorm2d , aa_layer = None ,
drop_block = None , drop_path = None ) :
def __init__ ( self , in_chs , out_chs , stride = 1 , dilation = ( 1 , 1 ) , bottle_ratio = 1 , group_size = 1 , se_ratio = 0.25 ,
downsample = ' conv1x1 ' , linear_out = Fals e, act_layer = nn . ReLU , norm_layer = nn . BatchNorm2d ,
drop_block = None , drop_path _rate= 0. ) :
super ( Bottleneck , self ) . __init__ ( )
bottleneck_chs = int ( round ( out_chs * bottleneck_ratio ) )
groups = bottleneck_chs / / group_width
cargs = dict ( act_layer = act_layer , norm_layer = norm_layer , aa_layer = aa_layer , drop_block = drop_block )
self . conv1 = ConvBnAct ( in_chs , bottleneck_chs , kernel_size = 1 , * * cargs )
self . conv2 = ConvBnAct (
bottleneck_chs , bottleneck_chs , kernel_size = 3 , stride = stride , dilation = dilation ,
groups = groups , * * cargs )
act_layer = get_act_layer ( act_layer )
bottleneck_chs = int ( round ( out_chs * bottle_ratio ) )
groups = bottleneck_chs / / group_size
cargs = dict ( act_layer = act_layer , norm_layer = norm_layer )
self . conv1 = ConvNormAct ( in_chs , bottleneck_chs , kernel_size = 1 , * * cargs )
self . conv2 = ConvNormAct (
bottleneck_chs , bottleneck_chs , kernel_size = 3 , stride = stride , dilation = dilation [ 0 ] ,
groups = groups , drop_layer = drop_block , * * cargs )
if se_ratio :
se_channels = int ( round ( in_chs * se_ratio ) )
self . se = SEModule ( bottleneck_chs , rd_channels = se_channels )
self . se = SEModule ( bottleneck_chs , rd_channels = se_channels , act_layer = act_layer )
else :
self . se = None
cargs [ ' act_layer ' ] = None
self . conv3 = ConvBnAct ( bottleneck_chs , out_chs , kernel_size = 1 , * * cargs )
self . act3 = act_layer ( inplace = True )
self . downsample = downsample
self . drop_path = drop_path
def zero_init_last_bn ( self ) :
self . se = nn . Identity ( )
self . conv3 = ConvNormAct ( bottleneck_chs , out_chs , kernel_size = 1 , apply_act = False , * * cargs )
self . act3 = nn . Identity ( ) if linear_out else act_layer ( )
self . downsample = create_shortcut ( downsample , in_chs , out_chs , 1 , stride , dilation , norm_layer = norm_layer )
self . drop_path = DropPath ( drop_path_rate ) if drop_path_rate > 0 else nn . Identity ( )
def zero_init_last ( self ) :
nn . init . zeros_ ( self . conv3 . bn . weight )
def forward ( self , x ) :
shortcut = x
x = self . conv1 ( x )
x = self . conv2 ( x )
if self . se is not None :
x = self . se ( x )
x = self . se ( x )
x = self . conv3 ( x )
if self . drop_path is not None :
x = self . drop_path ( x )
if self . downsample is not None :
shortcut = self . downsample ( shortcut )
x + = shortcut
# NOTE stuck with downsample as the attr name due to weight compatibility
# now represents the shortcut, no shortcut if None, and non-downsample shortcut == nn.Identity()
x = x + self . drop_path ( self . downsample ( shortcut ) )
x = self . act3 ( x )
return x
def downsample_conv (
in_chs , out_chs , kernel_size , stride = 1 , dilation = 1 , norm_layer = None ) :
norm_layer = norm_layer or nn . BatchNorm2d
kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size
dilation = dilation if kernel_size > 1 else 1
return ConvBnAct (
in_chs , out_chs , kernel_size , stride = stride , dilation = dilation , norm_layer = norm_layer , act_layer = None )
def downsample_avg (
in_chs , out_chs , kernel_size , stride = 1 , dilation = 1 , norm_layer = None ) :
""" AvgPool Downsampling as in ' D ' ResNet variants. This is not in RegNet space but I might experiment. """
norm_layer = norm_layer or nn . BatchNorm2d
avg_stride = stride if dilation == 1 else 1
pool = nn . Identity ( )
if stride > 1 or dilation > 1 :
avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn . AvgPool2d
pool = avg_pool_fn ( 2 , avg_stride , ceil_mode = True , count_include_pad = False )
return nn . Sequential ( * [
pool , ConvBnAct ( in_chs , out_chs , 1 , stride = 1 , norm_layer = norm_layer , act_layer = None ) ] )
class RegStage ( nn . Module ) :
""" Stage (sequence of blocks w/ the same output shape). """
def __init__ ( self , in_chs , out_chs , stride , dilation , depth , bottle_ratio , group_width ,
block_fn = Bottleneck , se_ratio = 0. , drop_path_rates = None , drop_block = None ) :
def __init__ (
self , depth , in_chs , out_chs , stride , dilation , bottle_ratio = 1.0 , group_size = 8 , block_fn = Bottleneck ,
se_ratio = 0. , downsample = ' conv1x1 ' , linear_out = False , act_layer = nn . ReLU , norm_layer = nn . BatchNorm2d ,
drop_path_rates = None , drop_block = None ) :
super ( RegStage , self ) . __init__ ( )
block_kwargs = { } # FIXME setup to pass various aa, norm, act layer common args
block_kwargs = dict (
bottle_ratio = bottle_ratio , group_size = group_size , se_ratio = se_ratio , downsample = downsample ,
linear_out = linear_out , act_layer = act_layer , norm_layer = norm_layer , drop_block = drop_block )
first_dilation = 1 if dilation in ( 1 , 2 ) else 2
for i in range ( depth ) :
block_stride = stride if i == 0 else 1
block_in_chs = in_chs if i == 0 else out_chs
block_dilation = first_dilation if i == 0 else dilation
if drop_path_rates is not None and drop_path_rates [ i ] > 0. :
drop_path = DropPath ( drop_path_rates [ i ] )
else :
drop_path = None
if ( block_in_chs != out_chs ) or ( block_stride != 1 ) :
proj_block = downsample_conv ( block_in_chs , out_chs , 1 , block_stride , block_dilation )
else :
proj_block = None
block_dilation = ( first_dilation , dilation )
dpr = drop_path_rates [ i ] if drop_path_rates is not None else 0.
name = " b {} " . format ( i + 1 )
self . add_module (
name , block_fn (
block_in_chs , out_chs , block_stride, block_dilation , bottle_ratio , group_width , se_ratio ,
d ownsample= proj_block , drop_block = drop_block , drop_path = drop_path , * * block_kwargs )
block_in_chs , out_chs , stride = block_stride , dilation = block_dilation ,
drop_path_rate = dpr , * * block_kwargs )
)
first_dilation = dilation
def forward ( self , x ) :
for block in self . children ( ) :
@ -231,33 +275,34 @@ class RegStage(nn.Module):
class RegNet ( nn . Module ) :
""" RegNet model.
""" RegNet -X, Y, and Z Models
Paper : https : / / arxiv . org / abs / 2003.13678
Original Impl : https : / / github . com / facebookresearch / pycls / blob / master / pycls / models / regnet . py
"""
def __init__ ( self , cfg , in_chans = 3 , num_classes = 1000 , output_stride = 32 , global_pool = ' avg ' , drop_rate = 0. ,
drop_path_rate = 0. , zero_init_last_bn = True ) :
def __init__ (
self , cfg : RegNetCfg , in_chans = 3 , num_classes = 1000 , output_stride = 32 , global_pool = ' avg ' ,
drop_rate = 0. , drop_path_rate = 0. , zero_init_last = True ) :
super ( ) . __init__ ( )
# TODO add drop block, drop path, anti-aliasing, custom bn/act args
self . num_classes = num_classes
self . drop_rate = drop_rate
assert output_stride in ( 8 , 16 , 32 )
# Construct the stem
stem_width = cfg [ ' stem_width ' ]
self . stem = Conv Bn Act( in_chans , stem_width , 3 , stride = 2 )
stem_width = cfg . stem_width
self . stem = Conv Norm Act( in_chans , stem_width , 3 , stride = 2 , act_layer = cfg . act_layer , norm_layer = cfg . norm_layer )
self . feature_info = [ dict ( num_chs = stem_width , reduction = 2 , module = ' stem ' ) ]
# Construct the stages
prev_width = stem_width
curr_stride = 2
stage_params = self . _get_stage_params ( cfg , output_stride = output_stride , drop_path_rate = drop_path_rate )
se_ratio = cfg [ ' se_ratio ' ]
for i , stage_args in enumerate ( stage_params ) :
stage_name = " s {} " . format ( i + 1 )
self . add_module ( stage_name , RegStage ( prev_width , * * stage_args , se_ratio = se_ratio ) )
self . add_module ( stage_name , RegStage (
in_chs = prev_width , se_ratio = cfg . se_ratio , downsample = cfg . downsample , linear_out = cfg . linear_out ,
act_layer = cfg . act_layer , norm_layer = cfg . norm_layer , * * stage_args ) )
prev_width = stage_args [ ' out_chs ' ]
curr_stride * = stage_args [ ' stride ' ]
self . feature_info + = [ dict ( num_chs = prev_width , reduction = curr_stride , module = stage_name ) ]
@ -267,31 +312,18 @@ class RegNet(nn.Module):
self . head = ClassifierHead (
in_chs = prev_width , num_classes = num_classes , pool_type = global_pool , drop_rate = drop_rate )
for m in self . modules ( ) :
if isinstance ( m , nn . Conv2d ) :
nn . init . kaiming_normal_ ( m . weight , mode = ' fan_out ' , nonlinearity = ' relu ' )
elif isinstance ( m , nn . BatchNorm2d ) :
nn . init . ones_ ( m . weight )
nn . init . zeros_ ( m . bias )
elif isinstance ( m , nn . Linear ) :
nn . init . normal_ ( m . weight , mean = 0.0 , std = 0.01 )
nn . init . zeros_ ( m . bias )
if zero_init_last_bn :
for m in self . modules ( ) :
if hasattr ( m , ' zero_init_last_bn ' ) :
m . zero_init_last_bn ( )
def _get_stage_params ( self , cfg , default_stride = 2 , output_stride = 32 , drop_path_rate = 0. ) :
named_apply ( partial ( _init_weights , zero_init_last = zero_init_last ) , self )
def _get_stage_params ( self , cfg : RegNetCfg , default_stride = 2 , output_stride = 32 , drop_path_rate = 0. ) :
# Generate RegNet ws per block
w_a , w_0 , w_m , d = cfg [ ' wa ' ] , cfg [ ' w0 ' ] , cfg [ ' wm ' ] , cfg [ ' depth ' ]
widths , num_stages , _ , _ = generate_regnet ( w_a , w_0 , w_m , d )
widths , num_stages , _ , _ = generate_regnet ( cfg . wa , cfg . w0 , cfg . wm , cfg . depth )
# Convert to per stage format
stage_widths , stage_depths = np . unique ( widths , return_counts = True )
# Use the same group width, bottleneck mult and stride for each stage
stage_groups = [ cfg [ ' group_w ' ] for _ in range ( num_stages ) ]
stage_bottle_ratios = [ cfg [ ' bottle_ratio ' ] for _ in range ( num_stages ) ]
stage_groups = [ cfg . group_size for _ in range ( num_stages ) ]
stage_bottle_ratios = [ cfg . bottle_ratio for _ in range ( num_stages ) ]
stage_strides = [ ]
stage_dilations = [ ]
net_stride = 2
@ -305,11 +337,11 @@ class RegNet(nn.Module):
net_stride * = stride
stage_strides . append ( stride )
stage_dilations . append ( dilation )
stage_dpr = np . split ( np . linspace ( 0 , drop_path_rate , d) , np . cumsum ( stage_depths [ : - 1 ] ) )
stage_dpr = np . split ( np . linspace ( 0 , drop_path_rate , cfg. depth ) , np . cumsum ( stage_depths [ : - 1 ] ) )
# Adjust the compatibility of ws and gws
stage_widths , stage_groups = adjust_widths_groups_comp ( stage_widths , stage_bottle_ratios , stage_groups )
param_names = [ ' out_chs ' , ' stride ' , ' dilation ' , ' depth ' , ' bottle_ratio ' , ' group_ width ' , ' drop_path_rates ' ]
param_names = [ ' out_chs ' , ' stride ' , ' dilation ' , ' depth ' , ' bottle_ratio ' , ' group_ size ' , ' drop_path_rates ' ]
stage_params = [
dict ( zip ( param_names , params ) ) for params in
zip ( stage_widths , stage_strides , stage_dilations , stage_depths , stage_bottle_ratios , stage_groups ,
@ -333,6 +365,19 @@ class RegNet(nn.Module):
return x
def _init_weights ( module , name = ' ' , zero_init_last = False ) :
if isinstance ( module , nn . Conv2d ) :
nn . init . kaiming_normal_ ( module . weight , mode = ' fan_out ' , nonlinearity = ' relu ' )
elif isinstance ( module , nn . BatchNorm2d ) :
nn . init . ones_ ( module . weight )
nn . init . zeros_ ( module . bias )
elif isinstance ( module , nn . Linear ) :
nn . init . normal_ ( module . weight , mean = 0.0 , std = 0.01 )
nn . init . zeros_ ( module . bias )
elif hasattr ( module , ' zero_init_last ' ) :
module . zero_init_last ( )
def _filter_fn ( state_dict ) :
""" convert patch embedding weight from manual patchify + linear proj to conv """
if ' model ' in state_dict :
@ -492,3 +537,27 @@ def regnety_160(pretrained=False, **kwargs):
def regnety_320 ( pretrained = False , * * kwargs ) :
""" RegNetY-32GF """
return _create_regnet ( ' regnety_320 ' , pretrained , * * kwargs )
@register_model
def regnety_040s_gn ( pretrained = False , * * kwargs ) :
""" RegNetY-4.0GF w/ GroupNorm """
return _create_regnet ( ' regnety_040s_gn ' , pretrained , * * kwargs )
@register_model
def regnetz_005 ( pretrained = False , * * kwargs ) :
""" RegNetZ-500MF
NOTE : config found in https : / / github . com / facebookresearch / ClassyVision / blob / main / classy_vision / models / regnet . py
but it ' s not clear it is equivalent to paper model as not detailed in the paper.
"""
return _create_regnet ( ' regnetz_005 ' , pretrained , * * kwargs )
@register_model
def regnetz_040 ( pretrained = False , * * kwargs ) :
""" RegNetZ-4.0GF
NOTE : config found in https : / / github . com / facebookresearch / ClassyVision / blob / main / classy_vision / models / regnet . py
but it ' s not clear it is equivalent to paper model as not detailed in the paper.
"""
return _create_regnet ( ' regnetz_040 ' , pretrained , * * kwargs )