diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py
index c24ca0e7..1e402629 100644
--- a/timm/models/byobnet.py
+++ b/timm/models/byobnet.py
@@ -1004,9 +1004,10 @@ class BottleneckBlock(nn.Module):
     """ ResNet-like Bottleneck Block - 1x1 - kxk - 1x1
     """
 
-    def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
-                 downsample='avg', attn_last=False, linear_out=False, extra_conv=False, bottle_in=False,
-                 layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
+    def __init__(
+            self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
+            downsample='avg', attn_last=False, linear_out=False, extra_conv=False, bottle_in=False,
+            layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
         super(BottleneckBlock, self).__init__()
         layers = layers or LayerFn()
         mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
@@ -1061,9 +1062,10 @@ class DarkBlock(nn.Module):
     for more optimal compute.
     """
 
-    def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
-                 downsample='avg', attn_last=True, linear_out=False, layers: LayerFn = None, drop_block=None,
-                 drop_path_rate=0.):
+    def __init__(
+            self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
+            downsample='avg', attn_last=True, linear_out=False, layers: LayerFn = None, drop_block=None,
+            drop_path_rate=0.):
         super(DarkBlock, self).__init__()
         layers = layers or LayerFn()
         mid_chs = make_divisible(out_chs * bottle_ratio)
@@ -1111,9 +1113,10 @@ class EdgeBlock(nn.Module):
     FIXME is there a more common 3x3 + 1x1 conv block to name this after?
     """
 
-    def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
-                 downsample='avg', attn_last=False, linear_out=False, layers: LayerFn = None,
-                 drop_block=None, drop_path_rate=0.):
+    def __init__(
+            self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
+            downsample='avg', attn_last=False, linear_out=False, layers: LayerFn = None,
+            drop_block=None, drop_path_rate=0.):
         super(EdgeBlock, self).__init__()
         layers = layers or LayerFn()
         mid_chs = make_divisible(out_chs * bottle_ratio)
@@ -1158,8 +1161,9 @@ class RepVggBlock(nn.Module):
     This version does not currently support the deploy optimization. It is currently fixed in 'train' mode.
     """
 
-    def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
-                 downsample='', layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
+    def __init__(
+            self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
+            downsample='', layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
         super(RepVggBlock, self).__init__()
         layers = layers or LayerFn()
         groups = num_groups(group_size, in_chs)
@@ -1522,7 +1526,7 @@ class ByobNet(nn.Module):
         matcher = dict(
             stem=r'^stem',
             blocks=[
-                (r'^stages\.(\d+)' if coarse else r'^stages\.(\d+).(\d+)', None),
+                (r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.(\d+)', None),
                 (r'^final_conv', (99999,))
             ]
         )
diff --git a/timm/models/crossvit.py b/timm/models/crossvit.py
index 5a3260bf..764eb3fe 100644
--- a/timm/models/crossvit.py
+++ b/timm/models/crossvit.py
@@ -164,8 +164,9 @@ class CrossAttention(nn.Module):
 
 class CrossAttentionBlock(nn.Module):
 
-    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
-                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+    def __init__(
+            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+            drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
         super().__init__()
         self.norm1 = norm_layer(dim)
         self.attn = CrossAttention(
diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py
index 75c525bf..f8a87fab 100644
--- a/timm/models/cspnet.py
+++ b/timm/models/cspnet.py
@@ -157,9 +157,10 @@ class ResBottleneck(nn.Module):
     """ ResNe(X)t Bottleneck Block
     """
 
-    def __init__(self, in_chs, out_chs, dilation=1, bottle_ratio=0.25, groups=1,
-                 act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_last=False,
-                 attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
+    def __init__(
+            self, in_chs, out_chs, dilation=1, bottle_ratio=0.25, groups=1,
+            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_last=False,
+            attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
         super(ResBottleneck, self).__init__()
         mid_chs = int(round(out_chs * bottle_ratio))
         ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
@@ -199,9 +200,10 @@ class DarkBlock(nn.Module):
     """ DarkNet Block
     """
 
-    def __init__(self, in_chs, out_chs, dilation=1, bottle_ratio=0.5, groups=1,
-                 act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None,
-                 drop_block=None, drop_path=None):
+    def __init__(
+            self, in_chs, out_chs, dilation=1, bottle_ratio=0.5, groups=1,
+            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None,
+            drop_block=None, drop_path=None):
         super(DarkBlock, self).__init__()
         mid_chs = int(round(out_chs * bottle_ratio))
         ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
@@ -229,9 +231,10 @@ class DarkBlock(nn.Module):
 
 class CrossStage(nn.Module):
     """Cross Stage."""
-    def __init__(self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., exp_ratio=1.,
-                 groups=1, first_dilation=None, down_growth=False, cross_linear=False, block_dpr=None,
-                 block_fn=ResBottleneck, **block_kwargs):
+    def __init__(
+            self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., exp_ratio=1.,
+            groups=1, first_dilation=None, down_growth=False, cross_linear=False, block_dpr=None,
+            block_fn=ResBottleneck, **block_kwargs):
         super(CrossStage, self).__init__()
         first_dilation = first_dilation or dilation
         down_chs = out_chs if down_growth else in_chs  # grow downsample channels to output channels
@@ -280,8 +283,9 @@ class CrossStage(nn.Module):
 
 class DarkStage(nn.Module):
     """DarkNet stage."""
-    def __init__(self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., groups=1,
-                 first_dilation=None, block_fn=ResBottleneck, block_dpr=None, **block_kwargs):
+    def __init__(
+            self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., groups=1,
+            first_dilation=None, block_fn=ResBottleneck, block_dpr=None, **block_kwargs):
         super(DarkStage, self).__init__()
         first_dilation = first_dilation or dilation
 
@@ -387,10 +391,10 @@ class CspNet(nn.Module):
     def group_matcher(self, coarse=False):
         matcher = dict(
             stem=r'^stem',
-            blocks=r'^stages.(\d+)' if coarse else [
-                (r'^stages.(\d+).blocks.(\d+)', None),
-                (r'^stages.(\d+).*transition', MATCH_PREV_GROUP),  # map to last block in stage
-                (r'^stages.(\d+)', (0,)),
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
+                (r'^stages\.(\d+)\..*transition', MATCH_PREV_GROUP),  # map to last block in stage
+                (r'^stages\.(\d+)', (0,)),
             ]
         )
         return matcher
diff --git a/timm/models/deit.py b/timm/models/deit.py
index 3fd8655b..1251c373 100644
--- a/timm/models/deit.py
+++ b/timm/models/deit.py
@@ -85,7 +85,7 @@ class VisionTransformerDistilled(VisionTransformer):
         return dict(
             stem=r'^cls_token|pos_embed|patch_embed|dist_token',
             blocks=[
-                (r'^blocks.(\d+)', None),
+                (r'^blocks\.(\d+)', None),
                 (r'^norm', (99999,))]  # final norm w/ last block
         )
 
diff --git a/timm/models/densenet.py b/timm/models/densenet.py
index 304eda79..a46b86ad 100644
--- a/timm/models/densenet.py
+++ b/timm/models/densenet.py
@@ -45,8 +45,9 @@ default_cfgs = {
 
 
 class DenseLayer(nn.Module):
-    def __init__(self, num_input_features, growth_rate, bn_size, norm_layer=BatchNormAct2d,
-                 drop_rate=0., memory_efficient=False):
+    def __init__(
+            self, num_input_features, growth_rate, bn_size, norm_layer=BatchNormAct2d,
+            drop_rate=0., memory_efficient=False):
         super(DenseLayer, self).__init__()
         self.add_module('norm1', norm_layer(num_input_features)),
         self.add_module('conv1', nn.Conv2d(
@@ -113,8 +114,9 @@ class DenseLayer(nn.Module):
 class DenseBlock(nn.ModuleDict):
     _version = 2
 
-    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=nn.ReLU,
-                 drop_rate=0., memory_efficient=False):
+    def __init__(
+            self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=nn.ReLU,
+            drop_rate=0., memory_efficient=False):
         super(DenseBlock, self).__init__()
         for i in range(num_layers):
             layer = DenseLayer(
@@ -164,8 +166,8 @@ class DenseNet(nn.Module):
 
     def __init__(
             self, growth_rate=32, block_config=(6, 12, 24, 16), num_classes=1000, in_chans=3, global_pool='avg',
-            bn_size=4, stem_type='', norm_layer=BatchNormAct2d, aa_layer=None, drop_rate=0, memory_efficient=False,
-            aa_stem_only=True):
+            bn_size=4, stem_type='', norm_layer=BatchNormAct2d, aa_layer=None, drop_rate=0,
+            memory_efficient=False, aa_stem_only=True):
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         super(DenseNet, self).__init__()
@@ -252,10 +254,10 @@ class DenseNet(nn.Module):
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
         matcher = dict(
-            stem=r'^features.conv[012]|features.norm[012]|features.pool[012]',
-            blocks=r'^features.(?:denseblock|transition)(\d+)' if coarse else [
-                (r'^features.denseblock(\d+).denselayer(\d+)', None),
-                (r'^features.transition(\d+)', MATCH_PREV_GROUP)  # FIXME combine with previous denselayer
+            stem=r'^features\.conv[012]|features\.norm[012]|features\.pool[012]',
+            blocks=r'^features\.(?:denseblock|transition)(\d+)' if coarse else [
+                (r'^features\.denseblock(\d+)\.denselayer(\d+)', None),
+                (r'^features\.transition(\d+)', MATCH_PREV_GROUP)  # FIXME combine with previous denselayer
             ]
         )
         return matcher
diff --git a/timm/models/dla.py b/timm/models/dla.py
index bc1a7394..6ab1802d 100644
--- a/timm/models/dla.py
+++ b/timm/models/dla.py
@@ -323,8 +323,8 @@ class DLA(nn.Module):
             stem=r'^base_layer',
             blocks=r'^level(\d+)' if coarse else [
                 # an unusual arch, this achieves somewhat more granularity without getting super messy
-                (r'^level(\d+).tree(\d+)', None),
-                (r'^level(\d+).root', (2,)),
+                (r'^level(\d+)\.tree(\d+)', None),
+                (r'^level(\d+)\.root', (2,)),
                 (r'^level(\d+)', (1,))
             ]
         )
diff --git a/timm/models/dpn.py b/timm/models/dpn.py
index 616efdbb..95159729 100644
--- a/timm/models/dpn.py
+++ b/timm/models/dpn.py
@@ -243,10 +243,10 @@ class DPN(nn.Module):
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
         matcher = dict(
-            stem=r'^features.conv1',
+            stem=r'^features\.conv1',
             blocks=[
-                (r'^features.conv(\d+)' if coarse else r'^features.conv(\d+)_(\d+)', None),
-                (r'^features.conv5_bn_ac', (99999,))
+                (r'^features\.conv(\d+)' if coarse else r'^features\.conv(\d+)_(\d+)', None),
+                (r'^features\.conv5_bn_ac', (99999,))
             ]
         )
         return matcher
diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py
index 96c30c76..0500f76e 100644
--- a/timm/models/efficientnet.py
+++ b/timm/models/efficientnet.py
@@ -518,7 +518,7 @@ class EfficientNet(nn.Module):
         return dict(
             stem=r'^conv_stem|bn1',
             blocks=[
-                (r'^blocks.(\d+)' if coarse else r'^blocks.(\d+).(\d+)', None),
+                (r'^blocks\.(\d+)' if coarse else r'^blocks\.(\d+)\.(\d+)', None),
                 (r'conv_head|bn2', (99999,))
             ]
         )
diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py
index bedb04a5..e19af88b 100644
--- a/timm/models/ghostnet.py
+++ b/timm/models/ghostnet.py
@@ -193,7 +193,7 @@ class GhostNet(nn.Module):
         matcher = dict(
             stem=r'^conv_stem|bn1',
             blocks=[
-                (r'^blocks.(\d+)' if coarse else r'^blocks.(\d+).(\d+)', None),
+                (r'^blocks\.(\d+)' if coarse else r'^blocks\.(\d+)\.(\d+)', None),
                 (r'conv_head', (99999,))
             ]
         )
diff --git a/timm/models/gluon_xception.py b/timm/models/gluon_xception.py
index 17a197a0..a9c946b2 100644
--- a/timm/models/gluon_xception.py
+++ b/timm/models/gluon_xception.py
@@ -184,7 +184,7 @@ class Xception65(nn.Module):
         matcher = dict(
             stem=r'^conv[12]|bn[12]',
             blocks=[
-                (r'^mid.block(\d+)', None),
+                (r'^mid\.block(\d+)', None),
                 (r'^block(\d+)', None),
                 (r'^conv[345]|bn[345]', (99,)),
             ],
diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py
index 1a53f44d..7e9b096f 100644
--- a/timm/models/hrnet.py
+++ b/timm/models/hrnet.py
@@ -686,8 +686,8 @@ class HighResolutionNet(nn.Module):
         matcher = dict(
             stem=r'^conv[12]|bn[12]',
             blocks=r'^(?:layer|stage|transition)(\d+)' if coarse else [
-                (r'^layer(\d+).(\d+)', None),
-                (r'^stage(\d+).(\d+)', None),
+                (r'^layer(\d+)\.(\d+)', None),
+                (r'^stage(\d+)\.(\d+)', None),
                 (r'^transition(\d+)', (99999,)),
             ],
         )
diff --git a/timm/models/levit.py b/timm/models/levit.py
index b1dae17a..e93662ae 100644
--- a/timm/models/levit.py
+++ b/timm/models/levit.py
@@ -496,7 +496,7 @@ class Levit(nn.Module):
     def group_matcher(self, coarse=False):
         matcher = dict(
             stem=r'^cls_token|pos_embed|patch_embed',  # stem and embed
-            blocks=[(r'^blocks.(\d+)', None), (r'^norm', (99999,))]
+            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
         )
         return matcher
 
diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py
index 75cdf84b..ff91def6 100644
--- a/timm/models/mlp_mixer.py
+++ b/timm/models/mlp_mixer.py
@@ -291,7 +291,7 @@ class MlpMixer(nn.Module):
     def group_matcher(self, coarse=False):
         return dict(
             stem=r'^stem',  # stem and embed
-            blocks=[(r'^blocks.(\d+)', None), (r'^norm', (99999,))]
+            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
         )
 
     @torch.jit.ignore
diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py
index 79e468a0..4a791857 100644
--- a/timm/models/mobilenetv3.py
+++ b/timm/models/mobilenetv3.py
@@ -171,7 +171,7 @@ class MobileNetV3(nn.Module):
     def group_matcher(self, coarse=False):
         return dict(
             stem=r'^conv_stem|bn1',
-            blocks=r'^blocks.(\d+)' if coarse else r'^blocks.(\d+).(\d+)'
+            blocks=r'^blocks\.(\d+)' if coarse else r'^blocks\.(\d+)\.(\d+)'
         )
 
     @torch.jit.ignore
diff --git a/timm/models/nest.py b/timm/models/nest.py
index 655cd755..8692a2b1 100644
--- a/timm/models/nest.py
+++ b/timm/models/nest.py
@@ -334,8 +334,8 @@ class Nest(nn.Module):
         matcher = dict(
             stem=r'^patch_embed',  # stem and embed
             blocks=[
-                (r'^levels.(\d+)' if coarse else r'^levels.(\d+).transformer_encoder.(\d+)', None),
-                (r'^levels.(\d+).(?:pool|pos_embed)', (0,)),
+                (r'^levels\.(\d+)' if coarse else r'^levels\.(\d+)\.transformer_encoder\.(\d+)', None),
+                (r'^levels\.(\d+)\.(?:pool|pos_embed)', (0,)),
                 (r'^norm', (99999,))
             ]
         )
diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py
index 4b79da50..3a45410b 100644
--- a/timm/models/nfnet.py
+++ b/timm/models/nfnet.py
@@ -194,7 +194,6 @@ def _dm_nfnet_cfg(depths, channels=(256, 512, 1536, 1536), act_layer='gelu', ski
     return cfg
 
 
-
 model_cfgs = dict(
     # NFNet-F models w/ GELU compatible with DeepMind weights
     dm_nfnet_f0=_dm_nfnet_cfg(depths=(1, 2, 6, 3)),
@@ -550,7 +549,7 @@ class NormFreeNet(nn.Module):
         matcher = dict(
             stem=r'^stem',
             blocks=[
-                (r'^stages.(\d+)' if coarse else r'^stages.(\d+).(\d+)', None),
+                (r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.(\d+)', None),
                 (r'^final_conv', (99999,))
             ]
         )
diff --git a/timm/models/regnet.py b/timm/models/regnet.py
index 4cf4a698..87ea32a6 100644
--- a/timm/models/regnet.py
+++ b/timm/models/regnet.py
@@ -458,7 +458,7 @@ class RegNet(nn.Module):
     def group_matcher(self, coarse=False):
         return dict(
             stem=r'^stem',
-            blocks=r'^stages.(\d+)' if coarse else r'^stages.(\d+).blocks.(\d+)',
+            blocks=r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.blocks\.(\d+)',
         )
 
     @torch.jit.ignore
diff --git a/timm/models/resnet.py b/timm/models/resnet.py
index 33366ae7..7a5afb3b 100644
--- a/timm/models/resnet.py
+++ b/timm/models/resnet.py
@@ -315,9 +315,10 @@ def create_aa(aa_layer, channels, stride=2, enable=True):
 class BasicBlock(nn.Module):
     expansion = 1
 
-    def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
-                 reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
-                 attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
+    def __init__(
+            self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
+            reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
+            attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
         super(BasicBlock, self).__init__()
 
         assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
@@ -379,9 +380,10 @@ class BasicBlock(nn.Module):
 class Bottleneck(nn.Module):
     expansion = 4
 
-    def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
-                 reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
-                 attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
+    def __init__(
+            self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
+            reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
+            attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
         super(Bottleneck, self).__init__()
 
         width = int(math.floor(planes * (base_width / 64)) * cardinality)
@@ -561,48 +563,35 @@ class ResNet(nn.Module):
 
     Parameters
     ----------
-    block : Block
-        Class for the residual block. Options are BasicBlockGl, BottleneckGl.
-    layers : list of int
-        Numbers of layers in each block
-    num_classes : int, default 1000
-        Number of classification classes.
-    in_chans : int, default 3
-        Number of input (color) channels.
-    cardinality : int, default 1
-        Number of convolution groups for 3x3 conv in Bottleneck.
-    base_width : int, default 64
-        Factor determining bottleneck channels. `planes * base_width / 64 * cardinality`
-    stem_width : int, default 64
-        Number of channels in stem convolutions
+    block : Block, class for the residual block. Options are BasicBlockGl, BottleneckGl.
+    layers : list of int, number of layers in each block
+    num_classes : int, default 1000, number of classification classes.
+    in_chans : int, default 3, number of input (color) channels.
+    output_stride : int, default 32, output stride of the network, 32, 16, or 8.
+    global_pool : str, Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
+    cardinality : int, default 1, number of convolution groups for 3x3 conv in Bottleneck.
+    base_width : int, default 64, factor determining bottleneck channels. `planes * base_width / 64 * cardinality`
+    stem_width : int, default 64, number of channels in stem convolutions
     stem_type : str, default ''
         The type of stem:
           * '', default - a single 7x7 conv with a width of stem_width
           * 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2
           * 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2
-    block_reduce_first: int, default 1
-        Reduction factor for first convolution output width of residual blocks,
-        1 for all archs except senets, where 2
-    down_kernel_size: int, default 1
-        Kernel size of residual block downsampling path, 1x1 for most archs, 3x3 for senets
-    avg_down : bool, default False
-        Whether to use average pooling for projection skip connection between stages/downsample.
-    output_stride : int, default 32
-        Set the output stride of the network, 32, 16, or 8. Typically used in segmentation.
+    block_reduce_first : int, default 1
+        Reduction factor for first convolution output width of residual blocks, 1 for all archs except senets, where 2
+    down_kernel_size : int, default 1, kernel size of residual block downsample path, 1x1 for most, 3x3 for senets
+    avg_down : bool, default False, use average pooling for projection skip connection between stages/downsample.
     act_layer : nn.Module, activation layer
     norm_layer : nn.Module, normalization layer
     aa_layer : nn.Module, anti-aliasing layer
-    drop_rate : float, default 0.
-        Dropout probability before classifier, for training
-    global_pool : str, default 'avg'
-        Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
+    drop_rate : float, default 0. Dropout probability before classifier, for training
     """
 
-    def __init__(self, block, layers, num_classes=1000, in_chans=3,
-                 cardinality=1, base_width=64, stem_width=64, stem_type='', replace_stem_pool=False,
-                 output_stride=32, block_reduce_first=1, down_kernel_size=1, avg_down=False,
-                 act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0.0, drop_path_rate=0.,
-                 drop_block_rate=0., global_pool='avg', zero_init_last=True, block_args=None):
+    def __init__(
+            self, block, layers, num_classes=1000, in_chans=3, output_stride=32, global_pool='avg',
+            cardinality=1, base_width=64, stem_width=64, stem_type='', replace_stem_pool=False, block_reduce_first=1,
+            down_kernel_size=1, avg_down=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None,
+            drop_rate=0.0, drop_path_rate=0., drop_block_rate=0., zero_init_last=True, block_args=None):
         super(ResNet, self).__init__()
         block_args = block_args or dict()
         assert output_stride in (8, 16, 32)
@@ -712,12 +701,15 @@ class ResNet(nn.Module):
         x = self.layer4(x)
         return x
 
-    def forward(self, x):
-        x = self.forward_features(x)
+    def forward_head(self, x, pre_logits: bool = False):
         x = self.global_pool(x)
         if self.drop_rate:
             x = F.dropout(x, p=float(self.drop_rate), training=self.training)
-        x = self.fc(x)
+        return x if pre_logits else self.fc(x)
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.forward_head(x)
         return x
 
 
diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py
index 09b6207a..bde088db 100644
--- a/timm/models/resnetv2.py
+++ b/timm/models/resnetv2.py
@@ -411,8 +411,8 @@ class ResNetV2(nn.Module):
     def group_matcher(self, coarse=False):
         matcher = dict(
             stem=r'^stem',
-            blocks=r'^stages.(\d+)' if coarse else [
-                (r'^stages.(\d+).blocks.(\d+)', None),
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
                 (r'^norm', (99999,))
             ]
         )
diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py
index 902d344f..33e97222 100644
--- a/timm/models/rexnet.py
+++ b/timm/models/rexnet.py
@@ -173,7 +173,7 @@ class ReXNetV1(nn.Module):
     def group_matcher(self, coarse=False):
         matcher = dict(
             stem=r'^stem',
-            blocks=r'^features.(\d+)',
+            blocks=r'^features\.(\d+)',
         )
         return matcher
 
diff --git a/timm/models/senet.py b/timm/models/senet.py
index 7a7a5e1c..97e592e4 100644
--- a/timm/models/senet.py
+++ b/timm/models/senet.py
@@ -360,7 +360,7 @@ class SENet(nn.Module):
 
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
-        matcher = dict(stem=r'^layer0', blocks=r'^layer(\d+)' if coarse else r'^layer(\d+).(\d+)')
+        matcher = dict(stem=r'^layer0', blocks=r'^layer(\d+)' if coarse else r'^layer(\d+)\.(\d+)')
         return matcher
 
     @torch.jit.ignore
diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py
index 79d36b65..b8262749 100644
--- a/timm/models/swin_transformer.py
+++ b/timm/models/swin_transformer.py
@@ -525,9 +525,9 @@ class SwinTransformer(nn.Module):
     def group_matcher(self, coarse=False):
         return dict(
             stem=r'^absolute_pos_embed|patch_embed',  # stem and embed
-            blocks=r'^layers.(\d+)' if coarse else [
-                (r'^layers.(\d+).downsample', (0,)),
-                (r'^layers.(\d+).\w+.(\d+)', None),
+            blocks=r'^layers\.(\d+)' if coarse else [
+                (r'^layers\.(\d+).downsample', (0,)),
+                (r'^layers\.(\d+)\.\w+\.(\d+)', None),
                 (r'^norm', (99999,)),
             ]
         )
diff --git a/timm/models/tnt.py b/timm/models/tnt.py
index 63107a27..5b72b196 100644
--- a/timm/models/tnt.py
+++ b/timm/models/tnt.py
@@ -217,7 +217,7 @@ class TNT(nn.Module):
         matcher = dict(
             stem=r'^cls_token|patch_pos|pixel_pos|pixel_embed|norm[12]_proj|proj',  # stem and embed / pos
             blocks=[
-                (r'^blocks.(\d+)', None),
+                (r'^blocks\.(\d+)', None),
                 (r'^norm', (99999,)),
             ]
         )
diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py
index f5a1c99a..0457acf8 100644
--- a/timm/models/tresnet.py
+++ b/timm/models/tresnet.py
@@ -233,7 +233,7 @@ class TResNet(nn.Module):
 
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
-        matcher = dict(stem=r'^body.conv1', blocks=r'^body.layer(\d+)' if coarse else r'^body.layer(\d+).(\d+)')
+        matcher = dict(stem=r'^body\.conv1', blocks=r'^body\.layer(\d+)' if coarse else r'^body\.layer(\d+)\.(\d+)')
         return matcher
 
     @torch.jit.ignore
diff --git a/timm/models/twins.py b/timm/models/twins.py
index c6ca03ff..0626db37 100644
--- a/timm/models/twins.py
+++ b/timm/models/twins.py
@@ -327,11 +327,11 @@ class Twins(nn.Module):
         matcher = dict(
             stem=r'^patch_embeds.0',  # stem and embed
             blocks=[
-                (r'^(?:blocks|patch_embeds|pos_block).(\d+)', None),
+                (r'^(?:blocks|patch_embeds|pos_block)\.(\d+)', None),
                 ('^norm', (99999,))
             ] if coarse else [
-                (r'^blocks.(\d+).(\d+)', None),
-                (r'^(?:patch_embeds|pos_block).(\d+)', (0,)),
+                (r'^blocks\.(\d+)\.(\d+)', None),
+                (r'^(?:patch_embeds|pos_block)\.(\d+)', (0,)),
                 (r'^norm', (99999,))
             ]
         )
diff --git a/timm/models/vgg.py b/timm/models/vgg.py
index f671de22..caf96517 100644
--- a/timm/models/vgg.py
+++ b/timm/models/vgg.py
@@ -136,7 +136,7 @@ class VGG(nn.Module):
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
         # this treats BN layers as separate groups for bn variants, a lot of effort to fix that
-        return dict(stem=r'^features.0', blocks=r'^features.(\d+)')
+        return dict(stem=r'^features\.0', blocks=r'^features\.(\d+)')
 
     @torch.jit.ignore
     def set_grad_checkpointing(self, enable=True):
diff --git a/timm/models/visformer.py b/timm/models/visformer.py
index 112f888b..254a0748 100644
--- a/timm/models/visformer.py
+++ b/timm/models/visformer.py
@@ -271,7 +271,7 @@ class Visformer(nn.Module):
         return dict(
             stem=r'^patch_embed1|pos_embed1|stem',  # stem and embed
             blocks=[
-                (r'^stage(\d+).(\d+)' if coarse else r'^stage(\d+).(\d+)', None),
+                (r'^stage(\d+)\.(\d+)' if coarse else r'^stage(\d+)\.(\d+)', None),
                 (r'^(?:patch_embed|pos_embed)(\d+)', (0,)),
                 (r'^norm', (99999,))
             ]
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 1d6e79d8..79778ab1 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -331,7 +331,7 @@ class VisionTransformer(nn.Module):
     def group_matcher(self, coarse=False):
         return dict(
             stem=r'^cls_token|pos_embed|patch_embed',  # stem and embed
-            blocks=[(r'^blocks.(\d+)', None), (r'^norm', (99999,))]
+            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
         )
 
     @torch.jit.ignore
diff --git a/timm/models/vovnet.py b/timm/models/vovnet.py
index 59ee470f..39d37195 100644
--- a/timm/models/vovnet.py
+++ b/timm/models/vovnet.py
@@ -327,7 +327,7 @@ class VovNet(nn.Module):
     def group_matcher(self, coarse=False):
         return dict(
             stem=r'^stem',
-            blocks=r'^stages.(\d+)' if coarse else r'^stages.(\d+).blocks.(\d+)',
+            blocks=r'^stages\.(\d+)' if coarse else r'^stages\.(\d+).blocks\.(\d+)',
         )
 
     @torch.jit.ignore
diff --git a/timm/models/xception_aligned.py b/timm/models/xception_aligned.py
index 52fe57da..6bbce5e6 100644
--- a/timm/models/xception_aligned.py
+++ b/timm/models/xception_aligned.py
@@ -221,7 +221,7 @@ class XceptionAligned(nn.Module):
     def group_matcher(self, coarse=False):
         return dict(
             stem=r'^stem',
-            blocks=r'^blocks.(\d+)',
+            blocks=r'^blocks\.(\d+)',
         )
 
     @torch.jit.ignore
diff --git a/timm/models/xcit.py b/timm/models/xcit.py
index 7782d721..69b97d64 100644
--- a/timm/models/xcit.py
+++ b/timm/models/xcit.py
@@ -412,8 +412,8 @@ class XCiT(nn.Module):
     def group_matcher(self, coarse=False):
         return dict(
             stem=r'^cls_token|pos_embed|patch_embed',  # stem and embed
-            blocks=r'^blocks.(\d+)',
-            cls_attn_blocks=[(r'^cls_attn_blocks.(\d+)', None), (r'^norm', (99999,))]
+            blocks=r'^blocks\.(\d+)',
+            cls_attn_blocks=[(r'^cls_attn_blocks\.(\d+)', None), (r'^norm', (99999,))]
         )
 
     @torch.jit.ignore
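
Note on the recurring regex change above: nearly every group_matcher hunk escapes the literal dots in its patterns (e.g. r'^blocks.(\d+)' becomes r'^blocks\.(\d+)'), so that the pattern only matches the '.' separators in parameter names rather than any character. The sketch below shows how patterns of this shape can bucket named_parameters() into stem/block groups; the toy module and grouping loop are illustrative assumptions, not timm's actual grouping helper.

    import re

    import torch.nn as nn

    # Illustrative sketch only (not timm's grouping code): apply escaped
    # group_matcher-style patterns to parameter names from a toy module.
    model = nn.Sequential()
    model.add_module('conv_stem', nn.Conv2d(3, 8, 3))
    model.add_module('blocks', nn.Sequential(nn.Conv2d(8, 8, 3), nn.Conv2d(8, 8, 3)))

    stem_re = re.compile(r'^conv_stem|bn1')   # stem pattern, as in the efficientnet.py hunk above
    block_re = re.compile(r'^blocks\.(\d+)')  # '\.' matches only the literal dot separator

    for name, _ in model.named_parameters():
        if stem_re.match(name):
            group = 'stem'
        else:
            m = block_re.match(name)
            group = f'blocks.{m.group(1)}' if m else 'head'
        print(f'{name:20s} -> {group}')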