Fix correctness of some group matching regex (no impact on result), some formatting, missed forward_head for resnet

pull/1014/head
Ross Wightman 3 years ago
parent 94bcdebd73
commit 0862e6ebae

@ -1004,7 +1004,8 @@ class BottleneckBlock(nn.Module):
""" ResNet-like Bottleneck Block - 1x1 - kxk - 1x1
"""
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
downsample='avg', attn_last=False, linear_out=False, extra_conv=False, bottle_in=False,
layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
super(BottleneckBlock, self).__init__()
@ -1061,7 +1062,8 @@ class DarkBlock(nn.Module):
for more optimal compute.
"""
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
downsample='avg', attn_last=True, linear_out=False, layers: LayerFn = None, drop_block=None,
drop_path_rate=0.):
super(DarkBlock, self).__init__()
@ -1111,7 +1113,8 @@ class EdgeBlock(nn.Module):
FIXME is there a more common 3x3 + 1x1 conv block to name this after?
"""
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
downsample='avg', attn_last=False, linear_out=False, layers: LayerFn = None,
drop_block=None, drop_path_rate=0.):
super(EdgeBlock, self).__init__()
@ -1158,7 +1161,8 @@ class RepVggBlock(nn.Module):
This version does not currently support the deploy optimization. It is currently fixed in 'train' mode.
"""
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
downsample='', layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
super(RepVggBlock, self).__init__()
layers = layers or LayerFn()
@ -1522,7 +1526,7 @@ class ByobNet(nn.Module):
matcher = dict(
stem=r'^stem',
blocks=[
(r'^stages\.(\d+)' if coarse else r'^stages\.(\d+).(\d+)', None),
(r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.(\d+)', None),
(r'^final_conv', (99999,))
]
)

@ -164,7 +164,8 @@ class CrossAttention(nn.Module):
class CrossAttentionBlock(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)

@ -157,7 +157,8 @@ class ResBottleneck(nn.Module):
""" ResNe(X)t Bottleneck Block
"""
def __init__(self, in_chs, out_chs, dilation=1, bottle_ratio=0.25, groups=1,
def __init__(
self, in_chs, out_chs, dilation=1, bottle_ratio=0.25, groups=1,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_last=False,
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
super(ResBottleneck, self).__init__()
@ -199,7 +200,8 @@ class DarkBlock(nn.Module):
""" DarkNet Block
"""
def __init__(self, in_chs, out_chs, dilation=1, bottle_ratio=0.5, groups=1,
def __init__(
self, in_chs, out_chs, dilation=1, bottle_ratio=0.5, groups=1,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None,
drop_block=None, drop_path=None):
super(DarkBlock, self).__init__()
@ -229,7 +231,8 @@ class DarkBlock(nn.Module):
class CrossStage(nn.Module):
"""Cross Stage."""
def __init__(self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., exp_ratio=1.,
def __init__(
self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., exp_ratio=1.,
groups=1, first_dilation=None, down_growth=False, cross_linear=False, block_dpr=None,
block_fn=ResBottleneck, **block_kwargs):
super(CrossStage, self).__init__()
@ -280,7 +283,8 @@ class CrossStage(nn.Module):
class DarkStage(nn.Module):
"""DarkNet stage."""
def __init__(self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., groups=1,
def __init__(
self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., groups=1,
first_dilation=None, block_fn=ResBottleneck, block_dpr=None, **block_kwargs):
super(DarkStage, self).__init__()
first_dilation = first_dilation or dilation
@ -387,10 +391,10 @@ class CspNet(nn.Module):
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^stem',
blocks=r'^stages.(\d+)' if coarse else [
(r'^stages.(\d+).blocks.(\d+)', None),
(r'^stages.(\d+).*transition', MATCH_PREV_GROUP), # map to last block in stage
(r'^stages.(\d+)', (0,)),
blocks=r'^stages\.(\d+)' if coarse else [
(r'^stages\.(\d+)\.blocks\.(\d+)', None),
(r'^stages\.(\d+)\..*transition', MATCH_PREV_GROUP), # map to last block in stage
(r'^stages\.(\d+)', (0,)),
]
)
return matcher

@ -85,7 +85,7 @@ class VisionTransformerDistilled(VisionTransformer):
return dict(
stem=r'^cls_token|pos_embed|patch_embed|dist_token',
blocks=[
(r'^blocks.(\d+)', None),
(r'^blocks\.(\d+)', None),
(r'^norm', (99999,))] # final norm w/ last block
)

@ -45,7 +45,8 @@ default_cfgs = {
class DenseLayer(nn.Module):
def __init__(self, num_input_features, growth_rate, bn_size, norm_layer=BatchNormAct2d,
def __init__(
self, num_input_features, growth_rate, bn_size, norm_layer=BatchNormAct2d,
drop_rate=0., memory_efficient=False):
super(DenseLayer, self).__init__()
self.add_module('norm1', norm_layer(num_input_features)),
@ -113,7 +114,8 @@ class DenseLayer(nn.Module):
class DenseBlock(nn.ModuleDict):
_version = 2
def __init__(self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=nn.ReLU,
def __init__(
self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=nn.ReLU,
drop_rate=0., memory_efficient=False):
super(DenseBlock, self).__init__()
for i in range(num_layers):
@ -164,8 +166,8 @@ class DenseNet(nn.Module):
def __init__(
self, growth_rate=32, block_config=(6, 12, 24, 16), num_classes=1000, in_chans=3, global_pool='avg',
bn_size=4, stem_type='', norm_layer=BatchNormAct2d, aa_layer=None, drop_rate=0, memory_efficient=False,
aa_stem_only=True):
bn_size=4, stem_type='', norm_layer=BatchNormAct2d, aa_layer=None, drop_rate=0,
memory_efficient=False, aa_stem_only=True):
self.num_classes = num_classes
self.drop_rate = drop_rate
super(DenseNet, self).__init__()
@ -252,10 +254,10 @@ class DenseNet(nn.Module):
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^features.conv[012]|features.norm[012]|features.pool[012]',
blocks=r'^features.(?:denseblock|transition)(\d+)' if coarse else [
(r'^features.denseblock(\d+).denselayer(\d+)', None),
(r'^features.transition(\d+)', MATCH_PREV_GROUP) # FIXME combine with previous denselayer
stem=r'^features\.conv[012]|features\.norm[012]|features\.pool[012]',
blocks=r'^features\.(?:denseblock|transition)(\d+)' if coarse else [
(r'^features\.denseblock(\d+)\.denselayer(\d+)', None),
(r'^features\.transition(\d+)', MATCH_PREV_GROUP) # FIXME combine with previous denselayer
]
)
return matcher

@ -323,8 +323,8 @@ class DLA(nn.Module):
stem=r'^base_layer',
blocks=r'^level(\d+)' if coarse else [
# an unusual arch, this achieves somewhat more granularity without getting super messy
(r'^level(\d+).tree(\d+)', None),
(r'^level(\d+).root', (2,)),
(r'^level(\d+)\.tree(\d+)', None),
(r'^level(\d+)\.root', (2,)),
(r'^level(\d+)', (1,))
]
)

@ -243,10 +243,10 @@ class DPN(nn.Module):
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^features.conv1',
stem=r'^features\.conv1',
blocks=[
(r'^features.conv(\d+)' if coarse else r'^features.conv(\d+)_(\d+)', None),
(r'^features.conv5_bn_ac', (99999,))
(r'^features\.conv(\d+)' if coarse else r'^features\.conv(\d+)_(\d+)', None),
(r'^features\.conv5_bn_ac', (99999,))
]
)
return matcher

@ -518,7 +518,7 @@ class EfficientNet(nn.Module):
return dict(
stem=r'^conv_stem|bn1',
blocks=[
(r'^blocks.(\d+)' if coarse else r'^blocks.(\d+).(\d+)', None),
(r'^blocks\.(\d+)' if coarse else r'^blocks\.(\d+)\.(\d+)', None),
(r'conv_head|bn2', (99999,))
]
)

@ -193,7 +193,7 @@ class GhostNet(nn.Module):
matcher = dict(
stem=r'^conv_stem|bn1',
blocks=[
(r'^blocks.(\d+)' if coarse else r'^blocks.(\d+).(\d+)', None),
(r'^blocks\.(\d+)' if coarse else r'^blocks\.(\d+)\.(\d+)', None),
(r'conv_head', (99999,))
]
)

@ -184,7 +184,7 @@ class Xception65(nn.Module):
matcher = dict(
stem=r'^conv[12]|bn[12]',
blocks=[
(r'^mid.block(\d+)', None),
(r'^mid\.block(\d+)', None),
(r'^block(\d+)', None),
(r'^conv[345]|bn[345]', (99,)),
],

@ -686,8 +686,8 @@ class HighResolutionNet(nn.Module):
matcher = dict(
stem=r'^conv[12]|bn[12]',
blocks=r'^(?:layer|stage|transition)(\d+)' if coarse else [
(r'^layer(\d+).(\d+)', None),
(r'^stage(\d+).(\d+)', None),
(r'^layer(\d+)\.(\d+)', None),
(r'^stage(\d+)\.(\d+)', None),
(r'^transition(\d+)', (99999,)),
],
)

@ -496,7 +496,7 @@ class Levit(nn.Module):
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
blocks=[(r'^blocks.(\d+)', None), (r'^norm', (99999,))]
blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
)
return matcher

@ -291,7 +291,7 @@ class MlpMixer(nn.Module):
def group_matcher(self, coarse=False):
return dict(
stem=r'^stem', # stem and embed
blocks=[(r'^blocks.(\d+)', None), (r'^norm', (99999,))]
blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
)
@torch.jit.ignore

@ -171,7 +171,7 @@ class MobileNetV3(nn.Module):
def group_matcher(self, coarse=False):
return dict(
stem=r'^conv_stem|bn1',
blocks=r'^blocks.(\d+)' if coarse else r'^blocks.(\d+).(\d+)'
blocks=r'^blocks\.(\d+)' if coarse else r'^blocks\.(\d+)\.(\d+)'
)
@torch.jit.ignore

@ -334,8 +334,8 @@ class Nest(nn.Module):
matcher = dict(
stem=r'^patch_embed', # stem and embed
blocks=[
(r'^levels.(\d+)' if coarse else r'^levels.(\d+).transformer_encoder.(\d+)', None),
(r'^levels.(\d+).(?:pool|pos_embed)', (0,)),
(r'^levels\.(\d+)' if coarse else r'^levels\.(\d+)\.transformer_encoder\.(\d+)', None),
(r'^levels\.(\d+)\.(?:pool|pos_embed)', (0,)),
(r'^norm', (99999,))
]
)

@ -194,7 +194,6 @@ def _dm_nfnet_cfg(depths, channels=(256, 512, 1536, 1536), act_layer='gelu', ski
return cfg
model_cfgs = dict(
# NFNet-F models w/ GELU compatible with DeepMind weights
dm_nfnet_f0=_dm_nfnet_cfg(depths=(1, 2, 6, 3)),
@ -550,7 +549,7 @@ class NormFreeNet(nn.Module):
matcher = dict(
stem=r'^stem',
blocks=[
(r'^stages.(\d+)' if coarse else r'^stages.(\d+).(\d+)', None),
(r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.(\d+)', None),
(r'^final_conv', (99999,))
]
)

@ -458,7 +458,7 @@ class RegNet(nn.Module):
def group_matcher(self, coarse=False):
return dict(
stem=r'^stem',
blocks=r'^stages.(\d+)' if coarse else r'^stages.(\d+).blocks.(\d+)',
blocks=r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.blocks\.(\d+)',
)
@torch.jit.ignore

@ -315,7 +315,8 @@ def create_aa(aa_layer, channels, stride=2, enable=True):
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
def __init__(
self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
super(BasicBlock, self).__init__()
@ -379,7 +380,8 @@ class BasicBlock(nn.Module):
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
def __init__(
self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
super(Bottleneck, self).__init__()
@ -561,48 +563,35 @@ class ResNet(nn.Module):
Parameters
----------
block : Block
Class for the residual block. Options are BasicBlockGl, BottleneckGl.
layers : list of int
Numbers of layers in each block
num_classes : int, default 1000
Number of classification classes.
in_chans : int, default 3
Number of input (color) channels.
cardinality : int, default 1
Number of convolution groups for 3x3 conv in Bottleneck.
base_width : int, default 64
Factor determining bottleneck channels. `planes * base_width / 64 * cardinality`
stem_width : int, default 64
Number of channels in stem convolutions
block : Block, class for the residual block. Options are BasicBlockGl, BottleneckGl.
layers : list of int, number of layers in each block
num_classes : int, default 1000, number of classification classes.
in_chans : int, default 3, number of input (color) channels.
output_stride : int, default 32, output stride of the network, 32, 16, or 8.
global_pool : str, Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
cardinality : int, default 1, number of convolution groups for 3x3 conv in Bottleneck.
base_width : int, default 64, factor determining bottleneck channels. `planes * base_width / 64 * cardinality`
stem_width : int, default 64, number of channels in stem convolutions
stem_type : str, default ''
The type of stem:
* '', default - a single 7x7 conv with a width of stem_width
* 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2
* 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2
block_reduce_first: int, default 1
Reduction factor for first convolution output width of residual blocks,
1 for all archs except senets, where 2
down_kernel_size: int, default 1
Kernel size of residual block downsampling path, 1x1 for most archs, 3x3 for senets
avg_down : bool, default False
Whether to use average pooling for projection skip connection between stages/downsample.
output_stride : int, default 32
Set the output stride of the network, 32, 16, or 8. Typically used in segmentation.
block_reduce_first : int, default 1
Reduction factor for first convolution output width of residual blocks, 1 for all archs except senets, where 2
down_kernel_size : int, default 1, kernel size of residual block downsample path, 1x1 for most, 3x3 for senets
avg_down : bool, default False, use average pooling for projection skip connection between stages/downsample.
act_layer : nn.Module, activation layer
norm_layer : nn.Module, normalization layer
aa_layer : nn.Module, anti-aliasing layer
drop_rate : float, default 0.
Dropout probability before classifier, for training
global_pool : str, default 'avg'
Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
drop_rate : float, default 0. Dropout probability before classifier, for training
"""
def __init__(self, block, layers, num_classes=1000, in_chans=3,
cardinality=1, base_width=64, stem_width=64, stem_type='', replace_stem_pool=False,
output_stride=32, block_reduce_first=1, down_kernel_size=1, avg_down=False,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0.0, drop_path_rate=0.,
drop_block_rate=0., global_pool='avg', zero_init_last=True, block_args=None):
def __init__(
self, block, layers, num_classes=1000, in_chans=3, output_stride=32, global_pool='avg',
cardinality=1, base_width=64, stem_width=64, stem_type='', replace_stem_pool=False, block_reduce_first=1,
down_kernel_size=1, avg_down=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None,
drop_rate=0.0, drop_path_rate=0., drop_block_rate=0., zero_init_last=True, block_args=None):
super(ResNet, self).__init__()
block_args = block_args or dict()
assert output_stride in (8, 16, 32)
@ -712,12 +701,15 @@ class ResNet(nn.Module):
x = self.layer4(x)
return x
def forward(self, x):
x = self.forward_features(x)
def forward_head(self, x, pre_logits: bool = False):
x = self.global_pool(x)
if self.drop_rate:
x = F.dropout(x, p=float(self.drop_rate), training=self.training)
x = self.fc(x)
return x if pre_logits else self.fc(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x

@ -411,8 +411,8 @@ class ResNetV2(nn.Module):
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^stem',
blocks=r'^stages.(\d+)' if coarse else [
(r'^stages.(\d+).blocks.(\d+)', None),
blocks=r'^stages\.(\d+)' if coarse else [
(r'^stages\.(\d+)\.blocks\.(\d+)', None),
(r'^norm', (99999,))
]
)

@ -173,7 +173,7 @@ class ReXNetV1(nn.Module):
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^stem',
blocks=r'^features.(\d+)',
blocks=r'^features\.(\d+)',
)
return matcher

@ -360,7 +360,7 @@ class SENet(nn.Module):
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(stem=r'^layer0', blocks=r'^layer(\d+)' if coarse else r'^layer(\d+).(\d+)')
matcher = dict(stem=r'^layer0', blocks=r'^layer(\d+)' if coarse else r'^layer(\d+)\.(\d+)')
return matcher
@torch.jit.ignore

@ -525,9 +525,9 @@ class SwinTransformer(nn.Module):
def group_matcher(self, coarse=False):
return dict(
stem=r'^absolute_pos_embed|patch_embed', # stem and embed
blocks=r'^layers.(\d+)' if coarse else [
(r'^layers.(\d+).downsample', (0,)),
(r'^layers.(\d+).\w+.(\d+)', None),
blocks=r'^layers\.(\d+)' if coarse else [
(r'^layers\.(\d+).downsample', (0,)),
(r'^layers\.(\d+)\.\w+\.(\d+)', None),
(r'^norm', (99999,)),
]
)

@ -217,7 +217,7 @@ class TNT(nn.Module):
matcher = dict(
stem=r'^cls_token|patch_pos|pixel_pos|pixel_embed|norm[12]_proj|proj', # stem and embed / pos
blocks=[
(r'^blocks.(\d+)', None),
(r'^blocks\.(\d+)', None),
(r'^norm', (99999,)),
]
)

@ -233,7 +233,7 @@ class TResNet(nn.Module):
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(stem=r'^body.conv1', blocks=r'^body.layer(\d+)' if coarse else r'^body.layer(\d+).(\d+)')
matcher = dict(stem=r'^body\.conv1', blocks=r'^body\.layer(\d+)' if coarse else r'^body\.layer(\d+)\.(\d+)')
return matcher
@torch.jit.ignore

@ -327,11 +327,11 @@ class Twins(nn.Module):
matcher = dict(
stem=r'^patch_embeds.0', # stem and embed
blocks=[
(r'^(?:blocks|patch_embeds|pos_block).(\d+)', None),
(r'^(?:blocks|patch_embeds|pos_block)\.(\d+)', None),
('^norm', (99999,))
] if coarse else [
(r'^blocks.(\d+).(\d+)', None),
(r'^(?:patch_embeds|pos_block).(\d+)', (0,)),
(r'^blocks\.(\d+)\.(\d+)', None),
(r'^(?:patch_embeds|pos_block)\.(\d+)', (0,)),
(r'^norm', (99999,))
]
)

@ -136,7 +136,7 @@ class VGG(nn.Module):
@torch.jit.ignore
def group_matcher(self, coarse=False):
# this treats BN layers as separate groups for bn variants, a lot of effort to fix that
return dict(stem=r'^features.0', blocks=r'^features.(\d+)')
return dict(stem=r'^features\.0', blocks=r'^features\.(\d+)')
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):

@ -271,7 +271,7 @@ class Visformer(nn.Module):
return dict(
stem=r'^patch_embed1|pos_embed1|stem', # stem and embed
blocks=[
(r'^stage(\d+).(\d+)' if coarse else r'^stage(\d+).(\d+)', None),
(r'^stage(\d+)\.(\d+)' if coarse else r'^stage(\d+)\.(\d+)', None),
(r'^(?:patch_embed|pos_embed)(\d+)', (0,)),
(r'^norm', (99999,))
]

@ -331,7 +331,7 @@ class VisionTransformer(nn.Module):
def group_matcher(self, coarse=False):
return dict(
stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
blocks=[(r'^blocks.(\d+)', None), (r'^norm', (99999,))]
blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
)
@torch.jit.ignore

@ -327,7 +327,7 @@ class VovNet(nn.Module):
def group_matcher(self, coarse=False):
return dict(
stem=r'^stem',
blocks=r'^stages.(\d+)' if coarse else r'^stages.(\d+).blocks.(\d+)',
blocks=r'^stages\.(\d+)' if coarse else r'^stages\.(\d+).blocks\.(\d+)',
)
@torch.jit.ignore

@ -221,7 +221,7 @@ class XceptionAligned(nn.Module):
def group_matcher(self, coarse=False):
return dict(
stem=r'^stem',
blocks=r'^blocks.(\d+)',
blocks=r'^blocks\.(\d+)',
)
@torch.jit.ignore

@ -412,8 +412,8 @@ class XCiT(nn.Module):
def group_matcher(self, coarse=False):
return dict(
stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
blocks=r'^blocks.(\d+)',
cls_attn_blocks=[(r'^cls_attn_blocks.(\d+)', None), (r'^norm', (99999,))]
blocks=r'^blocks\.(\d+)',
cls_attn_blocks=[(r'^cls_attn_blocks\.(\d+)', None), (r'^norm', (99999,))]
)
@torch.jit.ignore

Loading…
Cancel
Save