Fix correctness of some group matching regex (no impact on result), some formatting, missed forward_head for resnet

pull/1014/head
Ross Wightman 2 years ago
parent 94bcdebd73
commit 0862e6ebae
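
Note on the regex changes below: a minimal sketch (not part of the commit, names are hypothetical) of why escaping the dots in the group_matcher patterns is strictly more correct even though real parameter names never exercise the difference:

import re

# Parameter names look like 'stages.0.1.conv.weight'. An unescaped '.' is a
# wildcard that matches any character; '\.' matches only a literal dot.
name = 'stages.0.1.conv.weight'
print(re.match(r'^stages.(\d+).(\d+)', name).groups())    # ('0', '1') - wildcard dots still work
print(re.match(r'^stages\.(\d+)\.(\d+)', name).groups())  # ('0', '1') - literal dots, same result

odd = 'stagesA0B1.conv.weight'  # hypothetical name, purely for illustration
print(re.match(r'^stages.(\d+).(\d+)', odd))     # matches - the wildcard is too permissive
print(re.match(r'^stages\.(\d+)\.(\d+)', odd))   # None - the escaped pattern rejects it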

@@ -1004,9 +1004,10 @@ class BottleneckBlock(nn.Module):
     """ ResNet-like Bottleneck Block - 1x1 - kxk - 1x1
     """
 
-    def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
-                 downsample='avg', attn_last=False, linear_out=False, extra_conv=False, bottle_in=False,
-                 layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
+    def __init__(
+            self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
+            downsample='avg', attn_last=False, linear_out=False, extra_conv=False, bottle_in=False,
+            layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
         super(BottleneckBlock, self).__init__()
         layers = layers or LayerFn()
         mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
@@ -1061,9 +1062,10 @@ class DarkBlock(nn.Module):
     for more optimal compute.
     """
 
-    def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
-                 downsample='avg', attn_last=True, linear_out=False, layers: LayerFn = None, drop_block=None,
-                 drop_path_rate=0.):
+    def __init__(
+            self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
+            downsample='avg', attn_last=True, linear_out=False, layers: LayerFn = None, drop_block=None,
+            drop_path_rate=0.):
         super(DarkBlock, self).__init__()
         layers = layers or LayerFn()
         mid_chs = make_divisible(out_chs * bottle_ratio)
@@ -1111,9 +1113,10 @@ class EdgeBlock(nn.Module):
     FIXME is there a more common 3x3 + 1x1 conv block to name this after?
     """
 
-    def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
-                 downsample='avg', attn_last=False, linear_out=False, layers: LayerFn = None,
-                 drop_block=None, drop_path_rate=0.):
+    def __init__(
+            self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
+            downsample='avg', attn_last=False, linear_out=False, layers: LayerFn = None,
+            drop_block=None, drop_path_rate=0.):
         super(EdgeBlock, self).__init__()
         layers = layers or LayerFn()
         mid_chs = make_divisible(out_chs * bottle_ratio)
@@ -1158,8 +1161,9 @@ class RepVggBlock(nn.Module):
     This version does not currently support the deploy optimization. It is currently fixed in 'train' mode.
     """
 
-    def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
-                 downsample='', layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
+    def __init__(
+            self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
+            downsample='', layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
         super(RepVggBlock, self).__init__()
         layers = layers or LayerFn()
         groups = num_groups(group_size, in_chs)
@@ -1522,7 +1526,7 @@ class ByobNet(nn.Module):
         matcher = dict(
             stem=r'^stem',
             blocks=[
-                (r'^stages\.(\d+)' if coarse else r'^stages\.(\d+).(\d+)', None),
+                (r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.(\d+)', None),
                 (r'^final_conv', (99999,))
             ]
         )

@@ -164,8 +164,9 @@ class CrossAttention(nn.Module):
 class CrossAttentionBlock(nn.Module):
-    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
-                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+    def __init__(
+            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+            drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
         super().__init__()
         self.norm1 = norm_layer(dim)
         self.attn = CrossAttention(

@@ -157,9 +157,10 @@ class ResBottleneck(nn.Module):
     """ ResNe(X)t Bottleneck Block
     """
 
-    def __init__(self, in_chs, out_chs, dilation=1, bottle_ratio=0.25, groups=1,
-                 act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_last=False,
-                 attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
+    def __init__(
+            self, in_chs, out_chs, dilation=1, bottle_ratio=0.25, groups=1,
+            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_last=False,
+            attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
         super(ResBottleneck, self).__init__()
         mid_chs = int(round(out_chs * bottle_ratio))
         ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
@@ -199,9 +200,10 @@ class DarkBlock(nn.Module):
     """ DarkNet Block
     """
 
-    def __init__(self, in_chs, out_chs, dilation=1, bottle_ratio=0.5, groups=1,
-                 act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None,
-                 drop_block=None, drop_path=None):
+    def __init__(
+            self, in_chs, out_chs, dilation=1, bottle_ratio=0.5, groups=1,
+            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None,
+            drop_block=None, drop_path=None):
         super(DarkBlock, self).__init__()
         mid_chs = int(round(out_chs * bottle_ratio))
         ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
@@ -229,9 +231,10 @@ class DarkBlock(nn.Module):
 class CrossStage(nn.Module):
     """Cross Stage."""
-    def __init__(self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., exp_ratio=1.,
-                 groups=1, first_dilation=None, down_growth=False, cross_linear=False, block_dpr=None,
-                 block_fn=ResBottleneck, **block_kwargs):
+    def __init__(
+            self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., exp_ratio=1.,
+            groups=1, first_dilation=None, down_growth=False, cross_linear=False, block_dpr=None,
+            block_fn=ResBottleneck, **block_kwargs):
         super(CrossStage, self).__init__()
         first_dilation = first_dilation or dilation
         down_chs = out_chs if down_growth else in_chs  # grow downsample channels to output channels
@@ -280,8 +283,9 @@ class CrossStage(nn.Module):
 class DarkStage(nn.Module):
     """DarkNet stage."""
-    def __init__(self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., groups=1,
-                 first_dilation=None, block_fn=ResBottleneck, block_dpr=None, **block_kwargs):
+    def __init__(
+            self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., groups=1,
+            first_dilation=None, block_fn=ResBottleneck, block_dpr=None, **block_kwargs):
         super(DarkStage, self).__init__()
         first_dilation = first_dilation or dilation
@@ -387,10 +391,10 @@ class CspNet(nn.Module):
     def group_matcher(self, coarse=False):
         matcher = dict(
             stem=r'^stem',
-            blocks=r'^stages.(\d+)' if coarse else [
-                (r'^stages.(\d+).blocks.(\d+)', None),
-                (r'^stages.(\d+).*transition', MATCH_PREV_GROUP),  # map to last block in stage
-                (r'^stages.(\d+)', (0,)),
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
+                (r'^stages\.(\d+)\..*transition', MATCH_PREV_GROUP),  # map to last block in stage
+                (r'^stages\.(\d+)', (0,)),
             ]
         )
         return matcher

@@ -85,7 +85,7 @@ class VisionTransformerDistilled(VisionTransformer):
         return dict(
             stem=r'^cls_token|pos_embed|patch_embed|dist_token',
             blocks=[
-                (r'^blocks.(\d+)', None),
+                (r'^blocks\.(\d+)', None),
                 (r'^norm', (99999,))]  # final norm w/ last block
         )

@@ -45,8 +45,9 @@ default_cfgs = {
 class DenseLayer(nn.Module):
-    def __init__(self, num_input_features, growth_rate, bn_size, norm_layer=BatchNormAct2d,
-                 drop_rate=0., memory_efficient=False):
+    def __init__(
+            self, num_input_features, growth_rate, bn_size, norm_layer=BatchNormAct2d,
+            drop_rate=0., memory_efficient=False):
         super(DenseLayer, self).__init__()
         self.add_module('norm1', norm_layer(num_input_features)),
         self.add_module('conv1', nn.Conv2d(
@@ -113,8 +114,9 @@ class DenseLayer(nn.Module):
 class DenseBlock(nn.ModuleDict):
     _version = 2
-    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=nn.ReLU,
-                 drop_rate=0., memory_efficient=False):
+    def __init__(
+            self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=nn.ReLU,
+            drop_rate=0., memory_efficient=False):
         super(DenseBlock, self).__init__()
         for i in range(num_layers):
             layer = DenseLayer(
@@ -164,8 +166,8 @@ class DenseNet(nn.Module):
     def __init__(
             self, growth_rate=32, block_config=(6, 12, 24, 16), num_classes=1000, in_chans=3, global_pool='avg',
-            bn_size=4, stem_type='', norm_layer=BatchNormAct2d, aa_layer=None, drop_rate=0, memory_efficient=False,
-            aa_stem_only=True):
+            bn_size=4, stem_type='', norm_layer=BatchNormAct2d, aa_layer=None, drop_rate=0,
+            memory_efficient=False, aa_stem_only=True):
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         super(DenseNet, self).__init__()
@@ -252,10 +254,10 @@ class DenseNet(nn.Module):
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
         matcher = dict(
-            stem=r'^features.conv[012]|features.norm[012]|features.pool[012]',
-            blocks=r'^features.(?:denseblock|transition)(\d+)' if coarse else [
-                (r'^features.denseblock(\d+).denselayer(\d+)', None),
-                (r'^features.transition(\d+)', MATCH_PREV_GROUP)  # FIXME combine with previous denselayer
+            stem=r'^features\.conv[012]|features\.norm[012]|features\.pool[012]',
+            blocks=r'^features\.(?:denseblock|transition)(\d+)' if coarse else [
+                (r'^features\.denseblock(\d+)\.denselayer(\d+)', None),
+                (r'^features\.transition(\d+)', MATCH_PREV_GROUP)  # FIXME combine with previous denselayer
             ]
         )
         return matcher

@@ -323,8 +323,8 @@ class DLA(nn.Module):
             stem=r'^base_layer',
             blocks=r'^level(\d+)' if coarse else [
                 # an unusual arch, this achieves somewhat more granularity without getting super messy
-                (r'^level(\d+).tree(\d+)', None),
-                (r'^level(\d+).root', (2,)),
+                (r'^level(\d+)\.tree(\d+)', None),
+                (r'^level(\d+)\.root', (2,)),
                 (r'^level(\d+)', (1,))
             ]
         )

@@ -243,10 +243,10 @@ class DPN(nn.Module):
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
         matcher = dict(
-            stem=r'^features.conv1',
+            stem=r'^features\.conv1',
             blocks=[
-                (r'^features.conv(\d+)' if coarse else r'^features.conv(\d+)_(\d+)', None),
-                (r'^features.conv5_bn_ac', (99999,))
+                (r'^features\.conv(\d+)' if coarse else r'^features\.conv(\d+)_(\d+)', None),
+                (r'^features\.conv5_bn_ac', (99999,))
             ]
         )
         return matcher

@@ -518,7 +518,7 @@ class EfficientNet(nn.Module):
         return dict(
             stem=r'^conv_stem|bn1',
             blocks=[
-                (r'^blocks.(\d+)' if coarse else r'^blocks.(\d+).(\d+)', None),
+                (r'^blocks\.(\d+)' if coarse else r'^blocks\.(\d+)\.(\d+)', None),
                 (r'conv_head|bn2', (99999,))
             ]
         )

@@ -193,7 +193,7 @@ class GhostNet(nn.Module):
         matcher = dict(
             stem=r'^conv_stem|bn1',
             blocks=[
-                (r'^blocks.(\d+)' if coarse else r'^blocks.(\d+).(\d+)', None),
+                (r'^blocks\.(\d+)' if coarse else r'^blocks\.(\d+)\.(\d+)', None),
                 (r'conv_head', (99999,))
             ]
         )

@@ -184,7 +184,7 @@ class Xception65(nn.Module):
         matcher = dict(
             stem=r'^conv[12]|bn[12]',
             blocks=[
-                (r'^mid.block(\d+)', None),
+                (r'^mid\.block(\d+)', None),
                 (r'^block(\d+)', None),
                 (r'^conv[345]|bn[345]', (99,)),
             ],

@@ -686,8 +686,8 @@ class HighResolutionNet(nn.Module):
         matcher = dict(
             stem=r'^conv[12]|bn[12]',
             blocks=r'^(?:layer|stage|transition)(\d+)' if coarse else [
-                (r'^layer(\d+).(\d+)', None),
-                (r'^stage(\d+).(\d+)', None),
+                (r'^layer(\d+)\.(\d+)', None),
+                (r'^stage(\d+)\.(\d+)', None),
                 (r'^transition(\d+)', (99999,)),
             ],
         )

@@ -496,7 +496,7 @@ class Levit(nn.Module):
     def group_matcher(self, coarse=False):
         matcher = dict(
             stem=r'^cls_token|pos_embed|patch_embed',  # stem and embed
-            blocks=[(r'^blocks.(\d+)', None), (r'^norm', (99999,))]
+            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
         )
         return matcher

@@ -291,7 +291,7 @@ class MlpMixer(nn.Module):
     def group_matcher(self, coarse=False):
         return dict(
             stem=r'^stem',  # stem and embed
-            blocks=[(r'^blocks.(\d+)', None), (r'^norm', (99999,))]
+            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
         )
 
     @torch.jit.ignore

@@ -171,7 +171,7 @@ class MobileNetV3(nn.Module):
     def group_matcher(self, coarse=False):
         return dict(
             stem=r'^conv_stem|bn1',
-            blocks=r'^blocks.(\d+)' if coarse else r'^blocks.(\d+).(\d+)'
+            blocks=r'^blocks\.(\d+)' if coarse else r'^blocks\.(\d+)\.(\d+)'
         )
 
     @torch.jit.ignore

@@ -334,8 +334,8 @@ class Nest(nn.Module):
         matcher = dict(
             stem=r'^patch_embed',  # stem and embed
             blocks=[
-                (r'^levels.(\d+)' if coarse else r'^levels.(\d+).transformer_encoder.(\d+)', None),
-                (r'^levels.(\d+).(?:pool|pos_embed)', (0,)),
+                (r'^levels\.(\d+)' if coarse else r'^levels\.(\d+)\.transformer_encoder\.(\d+)', None),
+                (r'^levels\.(\d+)\.(?:pool|pos_embed)', (0,)),
                 (r'^norm', (99999,))
             ]
         )

@@ -194,7 +194,6 @@ def _dm_nfnet_cfg(depths, channels=(256, 512, 1536, 1536), act_layer='gelu', ski
     return cfg
 
 
-
 model_cfgs = dict(
     # NFNet-F models w/ GELU compatible with DeepMind weights
     dm_nfnet_f0=_dm_nfnet_cfg(depths=(1, 2, 6, 3)),
@@ -550,7 +549,7 @@ class NormFreeNet(nn.Module):
         matcher = dict(
             stem=r'^stem',
             blocks=[
-                (r'^stages.(\d+)' if coarse else r'^stages.(\d+).(\d+)', None),
+                (r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.(\d+)', None),
                 (r'^final_conv', (99999,))
             ]
         )

@@ -458,7 +458,7 @@ class RegNet(nn.Module):
     def group_matcher(self, coarse=False):
         return dict(
             stem=r'^stem',
-            blocks=r'^stages.(\d+)' if coarse else r'^stages.(\d+).blocks.(\d+)',
+            blocks=r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.blocks\.(\d+)',
         )
 
     @torch.jit.ignore

@@ -315,9 +315,10 @@ def create_aa(aa_layer, channels, stride=2, enable=True):
 class BasicBlock(nn.Module):
     expansion = 1
-    def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
-                 reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
-                 attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
+    def __init__(
+            self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
+            reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
+            attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
         super(BasicBlock, self).__init__()
         assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
@@ -379,9 +380,10 @@ class BasicBlock(nn.Module):
 class Bottleneck(nn.Module):
     expansion = 4
-    def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
-                 reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
-                 attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
+    def __init__(
+            self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
+            reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
+            attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
         super(Bottleneck, self).__init__()
         width = int(math.floor(planes * (base_width / 64)) * cardinality)
@@ -561,48 +563,35 @@ class ResNet(nn.Module):
     Parameters
     ----------
-    block : Block
-        Class for the residual block. Options are BasicBlockGl, BottleneckGl.
-    layers : list of int
-        Numbers of layers in each block
-    num_classes : int, default 1000
-        Number of classification classes.
-    in_chans : int, default 3
-        Number of input (color) channels.
-    cardinality : int, default 1
-        Number of convolution groups for 3x3 conv in Bottleneck.
-    base_width : int, default 64
-        Factor determining bottleneck channels. `planes * base_width / 64 * cardinality`
-    stem_width : int, default 64
-        Number of channels in stem convolutions
+    block : Block, class for the residual block. Options are BasicBlockGl, BottleneckGl.
+    layers : list of int, number of layers in each block
+    num_classes : int, default 1000, number of classification classes.
+    in_chans : int, default 3, number of input (color) channels.
+    output_stride : int, default 32, output stride of the network, 32, 16, or 8.
+    global_pool : str, Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
+    cardinality : int, default 1, number of convolution groups for 3x3 conv in Bottleneck.
+    base_width : int, default 64, factor determining bottleneck channels. `planes * base_width / 64 * cardinality`
+    stem_width : int, default 64, number of channels in stem convolutions
     stem_type : str, default ''
         The type of stem:
           * '', default - a single 7x7 conv with a width of stem_width
           * 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2
           * 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2
-    block_reduce_first: int, default 1
-        Reduction factor for first convolution output width of residual blocks,
-        1 for all archs except senets, where 2
-    down_kernel_size: int, default 1
-        Kernel size of residual block downsampling path, 1x1 for most archs, 3x3 for senets
-    avg_down : bool, default False
-        Whether to use average pooling for projection skip connection between stages/downsample.
-    output_stride : int, default 32
-        Set the output stride of the network, 32, 16, or 8. Typically used in segmentation.
+    block_reduce_first : int, default 1
+        Reduction factor for first convolution output width of residual blocks, 1 for all archs except senets, where 2
+    down_kernel_size : int, default 1, kernel size of residual block downsample path, 1x1 for most, 3x3 for senets
+    avg_down : bool, default False, use average pooling for projection skip connection between stages/downsample.
     act_layer : nn.Module, activation layer
     norm_layer : nn.Module, normalization layer
     aa_layer : nn.Module, anti-aliasing layer
-    drop_rate : float, default 0.
-        Dropout probability before classifier, for training
-    global_pool : str, default 'avg'
-        Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
+    drop_rate : float, default 0. Dropout probability before classifier, for training
     """
 
-    def __init__(self, block, layers, num_classes=1000, in_chans=3,
-                 cardinality=1, base_width=64, stem_width=64, stem_type='', replace_stem_pool=False,
-                 output_stride=32, block_reduce_first=1, down_kernel_size=1, avg_down=False,
-                 act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0.0, drop_path_rate=0.,
-                 drop_block_rate=0., global_pool='avg', zero_init_last=True, block_args=None):
+    def __init__(
+            self, block, layers, num_classes=1000, in_chans=3, output_stride=32, global_pool='avg',
+            cardinality=1, base_width=64, stem_width=64, stem_type='', replace_stem_pool=False, block_reduce_first=1,
+            down_kernel_size=1, avg_down=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None,
+            drop_rate=0.0, drop_path_rate=0., drop_block_rate=0., zero_init_last=True, block_args=None):
         super(ResNet, self).__init__()
         block_args = block_args or dict()
         assert output_stride in (8, 16, 32)
@@ -712,12 +701,15 @@ class ResNet(nn.Module):
         x = self.layer4(x)
         return x
 
-    def forward(self, x):
-        x = self.forward_features(x)
+    def forward_head(self, x, pre_logits: bool = False):
         x = self.global_pool(x)
         if self.drop_rate:
             x = F.dropout(x, p=float(self.drop_rate), training=self.training)
-        x = self.fc(x)
+        return x if pre_logits else self.fc(x)
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.forward_head(x)
         return x
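
A quick usage sketch (not part of the diff) of the forward_features / forward_head split added above; 'resnet50' is only an example model name and this assumes the timm API at this commit:

import torch
import timm

model = timm.create_model('resnet50', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    feats = model.forward_features(x)                    # unpooled NCHW feature map
    pooled = model.forward_head(feats, pre_logits=True)  # pooled features, classifier skipped
    logits = model.forward_head(feats)                   # pooled features -> fc logits
print(feats.shape, pooled.shape, logits.shape)           # e.g. (1, 2048, 7, 7), (1, 2048), (1, 1000)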

@@ -411,8 +411,8 @@ class ResNetV2(nn.Module):
     def group_matcher(self, coarse=False):
         matcher = dict(
             stem=r'^stem',
-            blocks=r'^stages.(\d+)' if coarse else [
-                (r'^stages.(\d+).blocks.(\d+)', None),
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
                 (r'^norm', (99999,))
             ]
         )

@@ -173,7 +173,7 @@ class ReXNetV1(nn.Module):
     def group_matcher(self, coarse=False):
         matcher = dict(
             stem=r'^stem',
-            blocks=r'^features.(\d+)',
+            blocks=r'^features\.(\d+)',
         )
         return matcher

@@ -360,7 +360,7 @@ class SENet(nn.Module):
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
-        matcher = dict(stem=r'^layer0', blocks=r'^layer(\d+)' if coarse else r'^layer(\d+).(\d+)')
+        matcher = dict(stem=r'^layer0', blocks=r'^layer(\d+)' if coarse else r'^layer(\d+)\.(\d+)')
         return matcher
 
     @torch.jit.ignore

@@ -525,9 +525,9 @@ class SwinTransformer(nn.Module):
     def group_matcher(self, coarse=False):
         return dict(
             stem=r'^absolute_pos_embed|patch_embed',  # stem and embed
-            blocks=r'^layers.(\d+)' if coarse else [
-                (r'^layers.(\d+).downsample', (0,)),
-                (r'^layers.(\d+).\w+.(\d+)', None),
+            blocks=r'^layers\.(\d+)' if coarse else [
+                (r'^layers\.(\d+).downsample', (0,)),
+                (r'^layers\.(\d+)\.\w+\.(\d+)', None),
                 (r'^norm', (99999,)),
             ]
         )

@@ -217,7 +217,7 @@ class TNT(nn.Module):
         matcher = dict(
             stem=r'^cls_token|patch_pos|pixel_pos|pixel_embed|norm[12]_proj|proj',  # stem and embed / pos
             blocks=[
-                (r'^blocks.(\d+)', None),
+                (r'^blocks\.(\d+)', None),
                 (r'^norm', (99999,)),
             ]
         )

@@ -233,7 +233,7 @@ class TResNet(nn.Module):
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
-        matcher = dict(stem=r'^body.conv1', blocks=r'^body.layer(\d+)' if coarse else r'^body.layer(\d+).(\d+)')
+        matcher = dict(stem=r'^body\.conv1', blocks=r'^body\.layer(\d+)' if coarse else r'^body\.layer(\d+)\.(\d+)')
         return matcher
 
     @torch.jit.ignore

@@ -327,11 +327,11 @@ class Twins(nn.Module):
         matcher = dict(
             stem=r'^patch_embeds.0',  # stem and embed
             blocks=[
-                (r'^(?:blocks|patch_embeds|pos_block).(\d+)', None),
+                (r'^(?:blocks|patch_embeds|pos_block)\.(\d+)', None),
                 ('^norm', (99999,))
             ] if coarse else [
-                (r'^blocks.(\d+).(\d+)', None),
-                (r'^(?:patch_embeds|pos_block).(\d+)', (0,)),
+                (r'^blocks\.(\d+)\.(\d+)', None),
+                (r'^(?:patch_embeds|pos_block)\.(\d+)', (0,)),
                 (r'^norm', (99999,))
             ]
         )

@@ -136,7 +136,7 @@ class VGG(nn.Module):
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
         # this treats BN layers as separate groups for bn variants, a lot of effort to fix that
-        return dict(stem=r'^features.0', blocks=r'^features.(\d+)')
+        return dict(stem=r'^features\.0', blocks=r'^features\.(\d+)')
 
     @torch.jit.ignore
     def set_grad_checkpointing(self, enable=True):

@@ -271,7 +271,7 @@ class Visformer(nn.Module):
         return dict(
             stem=r'^patch_embed1|pos_embed1|stem',  # stem and embed
             blocks=[
-                (r'^stage(\d+).(\d+)' if coarse else r'^stage(\d+).(\d+)', None),
+                (r'^stage(\d+)\.(\d+)' if coarse else r'^stage(\d+)\.(\d+)', None),
                 (r'^(?:patch_embed|pos_embed)(\d+)', (0,)),
                 (r'^norm', (99999,))
             ]

@@ -331,7 +331,7 @@ class VisionTransformer(nn.Module):
     def group_matcher(self, coarse=False):
         return dict(
             stem=r'^cls_token|pos_embed|patch_embed',  # stem and embed
-            blocks=[(r'^blocks.(\d+)', None), (r'^norm', (99999,))]
+            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
         )
 
     @torch.jit.ignore

@@ -327,7 +327,7 @@ class VovNet(nn.Module):
     def group_matcher(self, coarse=False):
         return dict(
             stem=r'^stem',
-            blocks=r'^stages.(\d+)' if coarse else r'^stages.(\d+).blocks.(\d+)',
+            blocks=r'^stages\.(\d+)' if coarse else r'^stages\.(\d+).blocks\.(\d+)',
         )
 
     @torch.jit.ignore

@@ -221,7 +221,7 @@ class XceptionAligned(nn.Module):
     def group_matcher(self, coarse=False):
         return dict(
             stem=r'^stem',
-            blocks=r'^blocks.(\d+)',
+            blocks=r'^blocks\.(\d+)',
         )
 
     @torch.jit.ignore

@@ -412,8 +412,8 @@ class XCiT(nn.Module):
     def group_matcher(self, coarse=False):
         return dict(
             stem=r'^cls_token|pos_embed|patch_embed',  # stem and embed
-            blocks=r'^blocks.(\d+)',
-            cls_attn_blocks=[(r'^cls_attn_blocks.(\d+)', None), (r'^norm', (99999,))]
+            blocks=r'^blocks\.(\d+)',
+            cls_attn_blocks=[(r'^cls_attn_blocks\.(\d+)', None), (r'^norm', (99999,))]
         )
 
     @torch.jit.ignore
