From 8e4ac3549f65eefa6b094cd04876b19ed3ca7506 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 7 Jun 2021 17:14:19 -0700 Subject: [PATCH 01/16] All ScaledStdConv and StdConv uses default to using F.layernorm so that they work with PyTorch XLA. eps value tweaking is a WIP. --- timm/models/layers/std_conv.py | 52 +++++++++++++----------- timm/models/nfnet.py | 7 +++- timm/models/vision_transformer_hybrid.py | 2 +- 3 files changed, 34 insertions(+), 27 deletions(-) diff --git a/timm/models/layers/std_conv.py b/timm/models/layers/std_conv.py index b0cb1eeb..a1afc653 100644 --- a/timm/models/layers/std_conv.py +++ b/timm/models/layers/std_conv.py @@ -19,17 +19,22 @@ class StdConv2d(nn.Conv2d): """ def __init__( self, in_channel, out_channels, kernel_size, stride=1, padding=None, dilation=1, - groups=1, bias=False, eps=1e-5): + groups=1, bias=False, eps=1e-5, use_layernorm=True): if padding is None: padding = get_padding(kernel_size, stride, dilation) super().__init__( in_channel, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) self.eps = eps + self.use_layernorm = use_layernorm def get_weight(self): - std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False) - weight = (self.weight - mean) / (std + self.eps) + if self.use_layernorm: + # NOTE F.layer_norm is being used to compute (self.weight - mean) / (sqrt(var) + self.eps) in one op + weight = F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps) + else: + std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False) + weight = (self.weight - mean) / (std + self.eps) return weight def forward(self, x): @@ -45,17 +50,22 @@ class StdConv2dSame(nn.Conv2d): """ def __init__( self, in_channel, out_channels, kernel_size, stride=1, padding='SAME', dilation=1, - groups=1, bias=False, eps=1e-5): + groups=1, bias=False, eps=1e-5, use_layernorm=True): padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation) super().__init__( in_channel, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) self.same_pad = is_dynamic self.eps = eps + self.use_layernorm = use_layernorm def get_weight(self): - std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False) - weight = (self.weight - mean) / (std + self.eps) + if self.use_layernorm: + # NOTE F.layer_norm is being used to compute (self.weight - mean) / (sqrt(var) + self.eps) in one op + weight = F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps) + else: + std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False) + weight = (self.weight - mean) / (std + self.eps) return weight def forward(self, x): @@ -76,7 +86,7 @@ class ScaledStdConv2d(nn.Conv2d): def __init__( self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1, - bias=True, gamma=1.0, eps=1e-5, gain_init=1.0, use_layernorm=False): + bias=True, gamma=1.0, eps=1e-5, gain_init=1.0, use_layernorm=True): if padding is None: padding = get_padding(kernel_size, stride, dilation) super().__init__( @@ -84,16 +94,17 @@ class ScaledStdConv2d(nn.Conv2d): groups=groups, bias=bias) self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init)) self.scale = gamma * self.weight[0].numel() ** -0.5 # gamma * 1 / sqrt(fan-in) - self.eps = eps ** 2 if use_layernorm else eps + self.eps = eps self.use_layernorm = use_layernorm # experimental, slightly 
faster/less GPU memory to hijack LN kernel def get_weight(self): if self.use_layernorm: - weight = self.scale * F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps) + # NOTE F.layer_norm is being used to compute (self.weight - mean) / (sqrt(var) + self.eps) in one op + weight = F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps) else: std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False) - weight = self.scale * (self.weight - mean) / (std + self.eps) - return self.gain * weight + weight = (self.weight - mean) / (std + self.eps) + return weight.mul_(self.gain * self.scale) def forward(self, x): return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups) @@ -110,7 +121,7 @@ class ScaledStdConv2dSame(nn.Conv2d): def __init__( self, in_channels, out_channels, kernel_size, stride=1, padding='SAME', dilation=1, groups=1, - bias=True, gamma=1.0, eps=1e-5, gain_init=1.0, use_layernorm=False): + bias=True, gamma=1.0, eps=1e-5, gain_init=1.0, use_layernorm=True): padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation) super().__init__( in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, @@ -118,24 +129,17 @@ class ScaledStdConv2dSame(nn.Conv2d): self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init)) self.scale = gamma * self.weight[0].numel() ** -0.5 self.same_pad = is_dynamic - self.eps = eps ** 2 if use_layernorm else eps + self.eps = eps self.use_layernorm = use_layernorm # experimental, slightly faster/less GPU memory to hijack LN kernel - # NOTE an alternate formulation to consider, closer to DeepMind Haiku impl but doesn't seem - # to make much numerical difference (+/- .002 to .004) in top-1 during eval. 
- # def get_weight(self): - # var, mean = torch.var_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False) - # scale = torch.rsqrt((self.weight[0].numel() * var).clamp_(self.eps)) * self.gain - # weight = (self.weight - mean) * scale - # return self.gain * weight - def get_weight(self): if self.use_layernorm: - weight = self.scale * F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps) + # NOTE F.layer_norm is being used to compute (self.weight - mean) / (sqrt(var) + self.eps) in one op + weight = F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps) else: std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False) - weight = self.scale * (self.weight - mean) / (std + self.eps) - return self.gain * weight + weight = (self.weight - mean) / (std + self.eps) + return weight.mul_(self.gain * self.scale) def forward(self, x): if self.same_pad: diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index 593796a5..584495c3 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -166,6 +166,8 @@ class NfCfg: extra_conv: bool = False # extra 3x3 bottleneck convolution for NFNet models gamma_in_act: bool = False same_padding: bool = False + std_conv_eps: float = 1e-5 + std_conv_ln: bool = True # use layer-norm impl to normalize in std-conv, works in PyTorch XLA, slightly faster skipinit: bool = False # disabled by default, non-trivial performance impact zero_init_fc: bool = False act_layer: str = 'silu' @@ -482,10 +484,11 @@ class NormFreeNet(nn.Module): conv_layer = ScaledStdConv2dSame if cfg.same_padding else ScaledStdConv2d if cfg.gamma_in_act: act_layer = act_with_gamma(cfg.act_layer, gamma=_nonlin_gamma[cfg.act_layer]) - conv_layer = partial(conv_layer, eps=1e-4) # DM weights better with higher eps + conv_layer = partial(conv_layer, eps=cfg.std_conv_eps, use_layernorm=cfg.std_conv_ln) else: act_layer = get_act_layer(cfg.act_layer) - conv_layer = partial(conv_layer, gamma=_nonlin_gamma[cfg.act_layer]) + conv_layer = partial( + conv_layer, gamma=_nonlin_gamma[cfg.act_layer], eps=cfg.std_conv_eps, use_layernorm=cfg.std_conv_ln) attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None stem_chs = make_divisible((cfg.stem_chs or cfg.channels[0]) * cfg.width_factor, cfg.ch_div) diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index 9e5a62b2..a32ce019 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -118,7 +118,7 @@ def _resnetv2(layers=(3, 4, 9), **kwargs): padding_same = kwargs.get('padding_same', True) if padding_same: stem_type = 'same' - conv_layer = StdConv2dSame + conv_layer = partial(StdConv2dSame, eps=1e-5) else: stem_type = '' conv_layer = StdConv2d From 2f5ed2dec1e5b020bfd0c4271845e2288a223624 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 7 Jun 2021 17:15:04 -0700 Subject: [PATCH 02/16] Update `init_values` const for 24 and 36 layer ResMLP models --- timm/models/mlp_mixer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py index 5a6dce6f..6f53264a 100644 --- a/timm/models/mlp_mixer.py +++ b/timm/models/mlp_mixer.py @@ -422,7 +422,8 @@ def resmlp_24_224(pretrained=False, **kwargs): Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 """ model_args = dict( - patch_size=16, num_blocks=24, hidden_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, 
**kwargs) + patch_size=16, num_blocks=24, hidden_dim=384, mlp_ratio=4, + block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs) model = _create_mixer('resmlp_24_224', pretrained=pretrained, **model_args) return model @@ -433,7 +434,8 @@ def resmlp_36_224(pretrained=False, **kwargs): Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 """ model_args = dict( - patch_size=16, num_blocks=36, hidden_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs) + patch_size=16, num_blocks=36, hidden_dim=384, mlp_ratio=4, + block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs) model = _create_mixer('resmlp_36_224', pretrained=pretrained, **model_args) return model From ba2ca4b46440c9fcf579fc66ca6df3082db44475 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 12 Jun 2021 12:27:43 -0700 Subject: [PATCH 03/16] One codepath for stdconv, switch layernorm to batchnorm so gain included. Tweak epsilon values for nfnet, resnetv2, vit hybrid. --- timm/models/layers/std_conv.py | 78 ++++++++---------------- timm/models/nfnet.py | 6 +- timm/models/resnetv2.py | 6 +- timm/models/vision_transformer_hybrid.py | 8 +-- 4 files changed, 33 insertions(+), 65 deletions(-) diff --git a/timm/models/layers/std_conv.py b/timm/models/layers/std_conv.py index a1afc653..49b35875 100644 --- a/timm/models/layers/std_conv.py +++ b/timm/models/layers/std_conv.py @@ -18,27 +18,20 @@ class StdConv2d(nn.Conv2d): https://arxiv.org/abs/1903.10520v2 """ def __init__( - self, in_channel, out_channels, kernel_size, stride=1, padding=None, dilation=1, - groups=1, bias=False, eps=1e-5, use_layernorm=True): + self, in_channel, out_channels, kernel_size, stride=1, padding=None, + dilation=1, groups=1, bias=False, eps=1e-6): if padding is None: padding = get_padding(kernel_size, stride, dilation) super().__init__( in_channel, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) self.eps = eps - self.use_layernorm = use_layernorm - - def get_weight(self): - if self.use_layernorm: - # NOTE F.layer_norm is being used to compute (self.weight - mean) / (sqrt(var) + self.eps) in one op - weight = F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps) - else: - std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False) - weight = (self.weight - mean) / (std + self.eps) - return weight def forward(self, x): - x = F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups) + weight = F.batch_norm( + self.weight.view(1, self.out_channels, -1), None, None, + eps=self.eps, training=True, momentum=0.).reshape_as(self.weight) + x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) return x @@ -49,29 +42,22 @@ class StdConv2dSame(nn.Conv2d): https://arxiv.org/abs/1903.10520v2 """ def __init__( - self, in_channel, out_channels, kernel_size, stride=1, padding='SAME', dilation=1, - groups=1, bias=False, eps=1e-5, use_layernorm=True): + self, in_channel, out_channels, kernel_size, stride=1, padding='SAME', + dilation=1, groups=1, bias=False, eps=1e-6): padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation) super().__init__( in_channel, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) self.same_pad = is_dynamic self.eps = eps - self.use_layernorm = use_layernorm - - def get_weight(self): - if 
self.use_layernorm: - # NOTE F.layer_norm is being used to compute (self.weight - mean) / (sqrt(var) + self.eps) in one op - weight = F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps) - else: - std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False) - weight = (self.weight - mean) / (std + self.eps) - return weight def forward(self, x): if self.same_pad: x = pad_same(x, self.kernel_size, self.stride, self.dilation) - x = F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups) + weight = F.batch_norm( + self.weight.view(1, self.out_channels, -1), None, None, + eps=self.eps, training=True, momentum=0.).reshape_as(self.weight) + x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) return x @@ -85,8 +71,8 @@ class ScaledStdConv2d(nn.Conv2d): """ def __init__( - self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1, - bias=True, gamma=1.0, eps=1e-5, gain_init=1.0, use_layernorm=True): + self, in_channels, out_channels, kernel_size, stride=1, padding=None, + dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0): if padding is None: padding = get_padding(kernel_size, stride, dilation) super().__init__( @@ -95,19 +81,13 @@ class ScaledStdConv2d(nn.Conv2d): self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init)) self.scale = gamma * self.weight[0].numel() ** -0.5 # gamma * 1 / sqrt(fan-in) self.eps = eps - self.use_layernorm = use_layernorm # experimental, slightly faster/less GPU memory to hijack LN kernel - - def get_weight(self): - if self.use_layernorm: - # NOTE F.layer_norm is being used to compute (self.weight - mean) / (sqrt(var) + self.eps) in one op - weight = F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps) - else: - std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False) - weight = (self.weight - mean) / (std + self.eps) - return weight.mul_(self.gain * self.scale) def forward(self, x): - return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups) + weight = F.batch_norm( + self.weight.view(1, self.out_channels, -1), None, None, + weight=(self.gain * self.scale).view(-1), + eps=self.eps, training=True, momentum=0.).reshape_as(self.weight) + return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) class ScaledStdConv2dSame(nn.Conv2d): @@ -120,8 +100,8 @@ class ScaledStdConv2dSame(nn.Conv2d): """ def __init__( - self, in_channels, out_channels, kernel_size, stride=1, padding='SAME', dilation=1, groups=1, - bias=True, gamma=1.0, eps=1e-5, gain_init=1.0, use_layernorm=True): + self, in_channels, out_channels, kernel_size, stride=1, padding='SAME', + dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0): padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation) super().__init__( in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, @@ -130,18 +110,12 @@ class ScaledStdConv2dSame(nn.Conv2d): self.scale = gamma * self.weight[0].numel() ** -0.5 self.same_pad = is_dynamic self.eps = eps - self.use_layernorm = use_layernorm # experimental, slightly faster/less GPU memory to hijack LN kernel - - def get_weight(self): - if self.use_layernorm: - # NOTE F.layer_norm is being used to compute (self.weight - mean) / (sqrt(var) + self.eps) in one op - weight = F.layer_norm(self.weight, 
self.weight.shape[1:], eps=self.eps) - else: - std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False) - weight = (self.weight - mean) / (std + self.eps) - return weight.mul_(self.gain * self.scale) def forward(self, x): if self.same_pad: x = pad_same(x, self.kernel_size, self.stride, self.dilation) - return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups) + weight = F.batch_norm( + self.weight.view(1, self.out_channels, -1), None, None, + weight=(self.gain * self.scale).view(-1), + eps=self.eps, training=True, momentum=0.).reshape_as(self.weight) + return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index 584495c3..fc0a20c2 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -167,7 +167,6 @@ class NfCfg: gamma_in_act: bool = False same_padding: bool = False std_conv_eps: float = 1e-5 - std_conv_ln: bool = True # use layer-norm impl to normalize in std-conv, works in PyTorch XLA, slightly faster skipinit: bool = False # disabled by default, non-trivial performance impact zero_init_fc: bool = False act_layer: str = 'silu' @@ -484,11 +483,10 @@ class NormFreeNet(nn.Module): conv_layer = ScaledStdConv2dSame if cfg.same_padding else ScaledStdConv2d if cfg.gamma_in_act: act_layer = act_with_gamma(cfg.act_layer, gamma=_nonlin_gamma[cfg.act_layer]) - conv_layer = partial(conv_layer, eps=cfg.std_conv_eps, use_layernorm=cfg.std_conv_ln) + conv_layer = partial(conv_layer, eps=cfg.std_conv_eps) else: act_layer = get_act_layer(cfg.act_layer) - conv_layer = partial( - conv_layer, gamma=_nonlin_gamma[cfg.act_layer], eps=cfg.std_conv_eps, use_layernorm=cfg.std_conv_ln) + conv_layer = partial(conv_layer, gamma=_nonlin_gamma[cfg.act_layer], eps=cfg.std_conv_eps) attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None stem_chs = make_divisible((cfg.stem_chs or cfg.channels[0]) * cfg.width_factor, cfg.ch_div) diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index 0ca6fba9..250695a8 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -276,7 +276,7 @@ class ResNetStage(nn.Module): def create_resnetv2_stem( in_chs, out_chs=64, stem_type='', preact=True, - conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32)): + conv_layer=partial(StdConv2d, eps=1e-8), norm_layer=partial(GroupNormAct, num_groups=32)): stem = OrderedDict() assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same') @@ -315,8 +315,8 @@ class ResNetV2(nn.Module): def __init__(self, layers, channels=(256, 512, 1024, 2048), num_classes=1000, in_chans=3, global_pool='avg', output_stride=32, width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True, - act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32), - drop_rate=0., drop_path_rate=0.): + act_layer=nn.ReLU, conv_layer=partial(StdConv2d, eps=1e-8), + norm_layer=partial(GroupNormAct, num_groups=32), drop_rate=0., drop_path_rate=0.): super().__init__() self.num_classes = num_classes self.drop_rate = drop_rate diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index a32ce019..7fc0cc88 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -116,12 +116,8 @@ def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwa def _resnetv2(layers=(3, 4, 9), 
**kwargs): """ ResNet-V2 backbone helper""" padding_same = kwargs.get('padding_same', True) - if padding_same: - stem_type = 'same' - conv_layer = partial(StdConv2dSame, eps=1e-5) - else: - stem_type = '' - conv_layer = StdConv2d + stem_type = 'same' if padding_same else '' + conv_layer = partial(StdConv2dSame, eps=1e-8) if padding_same else partial(StdConv2d, eps=1e-8) if len(layers): backbone = ResNetV2( layers=layers, num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3), From 8880f696b6b8368a76296126476ea020fc7c814c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 12 Jun 2021 16:40:02 -0700 Subject: [PATCH 04/16] Refactoring, cleanup, improved test coverage. * Add eca_nfnet_l2 weights, 84.7 @ 384x384 * All 'non-std' (ie transformer / mlp) models have classifier / default_cfg test added * Fix #694 reset_classifer / num_features / forward_features / num_classes=0 consistency for transformer / mlp models * Add direct loading of npz to vision transformer (pure transformer so far, hybrid to come) * Rename vit_deit* to deit_* * Remove some deprecated vit hybrid model defs * Clean up classifier flatten for conv classifiers and unusual cases (mobilenetv3/ghostnet) * Remove explicit model fns for levit conv, just pass in arg --- tests/test_models.py | 55 ++++- timm/models/cait.py | 30 ++- timm/models/coat.py | 8 +- timm/models/convit.py | 23 +- timm/models/dla.py | 6 +- timm/models/dpn.py | 5 +- timm/models/ghostnet.py | 9 +- timm/models/helpers.py | 29 +++ timm/models/layers/adaptive_avgmax_pool.py | 13 +- timm/models/layers/classifier.py | 5 +- timm/models/layers/mlp.py | 6 + timm/models/levit.py | 87 ++++--- timm/models/mlp_mixer.py | 116 +++++---- timm/models/mobilenetv3.py | 5 +- timm/models/nfnet.py | 7 +- timm/models/pit.py | 32 ++- timm/models/registry.py | 13 +- timm/models/resnet.py | 7 +- timm/models/resnetv2.py | 266 ++++++++++----------- timm/models/swin_transformer.py | 43 ++-- timm/models/twins.py | 8 +- timm/models/visformer.py | 63 ++--- timm/models/vision_transformer.py | 228 ++++++++++++------ timm/models/vision_transformer_hybrid.py | 28 +-- 24 files changed, 637 insertions(+), 455 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 5a31935e..ac156806 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -17,7 +17,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'): # transformer models don't support many of the spatial / feature based model functionalities NON_STD_FILTERS = [ 'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', - 'convit_*', 'levit*', 'visformer*'] + 'convit_*', 'levit*', 'visformer*', 'deit*'] NUM_NON_STD = len(NON_STD_FILTERS) # exclude models that cause specific test failures @@ -120,7 +120,6 @@ def test_model_default_cfgs(model_name, batch_size): state_dict = model.state_dict() cfg = model.default_cfg - classifier = cfg['classifier'] pool_size = cfg['pool_size'] input_size = model.default_cfg['input_size'] @@ -149,7 +148,57 @@ def test_model_default_cfgs(model_name, batch_size): assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2] # check classifier name matches default_cfg - assert classifier + ".weight" in state_dict.keys(), f'{classifier} not in model params' + classifier = cfg['classifier'] + if not isinstance(classifier, (tuple, list)): + classifier = classifier, + for c in classifier: + assert c + ".weight" in state_dict.keys(), f'{c} not in model params' + + # check first conv(s) names match default_cfg + 
first_conv = cfg['first_conv'] + if isinstance(first_conv, str): + first_conv = (first_conv,) + assert isinstance(first_conv, (tuple, list)) + for fc in first_conv: + assert fc + ".weight" in state_dict.keys(), f'{fc} not in model params' + + +@pytest.mark.timeout(300) +@pytest.mark.parametrize('model_name', list_models(filter=NON_STD_FILTERS)) +@pytest.mark.parametrize('batch_size', [1]) +def test_model_default_cfgs_non_std(model_name, batch_size): + """Run a single forward pass with each model""" + model = create_model(model_name, pretrained=False) + model.eval() + state_dict = model.state_dict() + cfg = model.default_cfg + + input_size = _get_input_size(model_name=model_name, target=TARGET_FWD_SIZE) + if max(input_size) > MAX_FWD_SIZE: + pytest.skip("Fixed input size model > limit.") + + input_tensor = torch.randn((batch_size, *input_size)) + + # test forward_features (always unpooled) + outputs = model.forward_features(input_tensor) + if isinstance(outputs, tuple): + outputs = outputs[0] + assert outputs.shape[1] == model.num_features + + # test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features + model.reset_classifier(0) + outputs = model.forward(input_tensor) + if isinstance(outputs, tuple): + outputs = outputs[0] + assert len(outputs.shape) == 2 + assert outputs.shape[1] == model.num_features + + # check classifier name matches default_cfg + classifier = cfg['classifier'] + if not isinstance(classifier, (tuple, list)): + classifier = classifier, + for c in classifier: + assert c + ".weight" in state_dict.keys(), f'{c} not in model params' # check first conv(s) names match default_cfg first_conv = cfg['first_conv'] diff --git a/timm/models/cait.py b/timm/models/cait.py index aa2e5f07..69b4ba06 100644 --- a/timm/models/cait.py +++ b/timm/models/cait.py @@ -74,11 +74,11 @@ default_cfgs = dict( class ClassAttn(nn.Module): # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # with slight modifications to do CA - def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 + self.scale = head_dim ** -0.5 self.q = nn.Linear(dim, dim, bias=qkv_bias) self.k = nn.Linear(dim, dim, bias=qkv_bias) @@ -110,13 +110,13 @@ class LayerScaleBlockClassAttn(nn.Module): # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # with slight modifications to add CA and LayerScale def __init__( - self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=ClassAttn, mlp_block=Mlp, init_values=1e-4): super().__init__() self.norm1 = norm_layer(dim) self.attn = attn_block( - dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) @@ -134,14 +134,14 @@ class LayerScaleBlockClassAttn(nn.Module): class TalkingHeadAttn(nn.Module): # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # with slight modifications to add Talking Heads Attention (https://arxiv.org/pdf/2003.02436v1.pdf) - def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 + self.scale = head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) @@ -177,13 +177,13 @@ class LayerScaleBlock(nn.Module): # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # with slight modifications to add layerScale def __init__( - self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=TalkingHeadAttn, mlp_block=Mlp, init_values=1e-4): super().__init__() self.norm1 = norm_layer(dim) self.attn = attn_block( - dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) @@ -202,7 +202,7 @@ class Cait(nn.Module): # with slight modifications to adapt to our cait models def __init__( self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, - num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., + num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), global_pool=None, @@ -235,14 +235,14 @@ class Cait(nn.Module): dpr = [drop_path_rate for i in range(depth)] self.blocks = nn.ModuleList([ block_layers( - dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, attn_block=attn_block, mlp_block=mlp_block, init_values=init_scale) for i in range(depth)]) self.blocks_token_only = nn.ModuleList([ block_layers_token( - dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio_clstk, qkv_bias=qkv_bias, qk_scale=qk_scale, + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio_clstk, qkv_bias=qkv_bias, drop=0.0, attn_drop=0.0, drop_path=0.0, norm_layer=norm_layer, act_layer=act_layer, attn_block=attn_block_token_only, mlp_block=mlp_block_token_only, init_values=init_scale) @@ -270,6 +270,13 @@ class Cait(nn.Module): def no_weight_decay(self): return {'pos_embed', 'cls_token'} + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + def forward_features(self, x): B = x.shape[0] x = self.patch_embed(x) @@ -293,7 +300,6 @@ class 
Cait(nn.Module): def forward(self, x): x = self.forward_features(x) x = self.head(x) - return x diff --git a/timm/models/coat.py b/timm/models/coat.py index 9eb384d8..f071715a 100644 --- a/timm/models/coat.py +++ b/timm/models/coat.py @@ -335,6 +335,8 @@ class CoaT(nn.Module): crpe_window = crpe_window or {3: 2, 5: 3, 7: 3} self.return_interm_layers = return_interm_layers self.out_features = out_features + self.embed_dims = embed_dims + self.num_features = embed_dims[-1] self.num_classes = num_classes # Patch embeddings. @@ -441,10 +443,10 @@ class CoaT(nn.Module): # CoaT series: Aggregate features of last three scales for classification. assert embed_dims[1] == embed_dims[2] == embed_dims[3] self.aggregate = torch.nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1) - self.head = nn.Linear(embed_dims[3], num_classes) + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() else: # CoaT-Lite series: Use feature of last scale for classification. - self.head = nn.Linear(embed_dims[3], num_classes) + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() # Initialize weights. trunc_normal_(self.cls_token1, std=.02) @@ -471,7 +473,7 @@ class CoaT(nn.Module): def reset_classifier(self, num_classes, global_pool=''): self.num_classes = num_classes - self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() def insert_cls(self, x, cls_token): """ Insert CLS token. """ diff --git a/timm/models/convit.py b/timm/models/convit.py index b15b46d8..0593ec1c 100644 --- a/timm/models/convit.py +++ b/timm/models/convit.py @@ -57,13 +57,13 @@ default_cfgs = { class GPSA(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., locality_strength=1.): super().__init__() self.num_heads = num_heads self.dim = dim head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 + self.scale = head_dim ** -0.5 self.locality_strength = locality_strength self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias) @@ -142,11 +142,11 @@ class GPSA(nn.Module): class MHSA(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 + self.scale = head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) @@ -191,19 +191,16 @@ class MHSA(nn.Module): class Block(nn.Module): - def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_gpsa=True, **kwargs): super().__init__() self.norm1 = norm_layer(dim) self.use_gpsa = use_gpsa if self.use_gpsa: self.attn = GPSA( - dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, - proj_drop=drop, **kwargs) + dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, **kwargs) else: - self.attn = MHSA( - dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, - proj_drop=drop, **kwargs) + 
self.attn = MHSA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) @@ -220,7 +217,7 @@ class ConViT(nn.Module): """ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, - num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, global_pool=None, local_up_to_layer=3, locality_strength=1., use_pos_embed=True): super().__init__() @@ -250,13 +247,13 @@ class ConViT(nn.Module): dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule self.blocks = nn.ModuleList([ Block( - dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, use_gpsa=True, locality_strength=locality_strength) if i < local_up_to_layer else Block( - dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, use_gpsa=False) for i in range(depth)]) diff --git a/timm/models/dla.py b/timm/models/dla.py index f0f25b0b..f6e4dd28 100644 --- a/timm/models/dla.py +++ b/timm/models/dla.py @@ -288,6 +288,8 @@ class DLA(nn.Module): self.num_features = channels[-1] self.global_pool, self.fc = create_classifier( self.num_features, self.num_classes, pool_type=global_pool, use_conv=True) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() + for m in self.modules(): if isinstance(m, nn.Conv2d): n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels @@ -314,6 +316,7 @@ class DLA(nn.Module): self.num_classes = num_classes self.global_pool, self.fc = create_classifier( self.num_features, self.num_classes, pool_type=global_pool, use_conv=True) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() def forward_features(self, x): x = self.base_layer(x) @@ -331,8 +334,7 @@ class DLA(nn.Module): if self.drop_rate > 0.: x = F.dropout(x, p=self.drop_rate, training=self.training) x = self.fc(x) - if not self.global_pool.is_identity(): - x = x.flatten(1) # conv classifier, flatten if pooling isn't pass-through (disabled) + x = self.flatten(x) return x diff --git a/timm/models/dpn.py b/timm/models/dpn.py index 90ef11cc..c4e380b1 100644 --- a/timm/models/dpn.py +++ b/timm/models/dpn.py @@ -237,6 +237,7 @@ class DPN(nn.Module): # Using 1x1 conv for the FC layer to allow the extra pooling scheme self.global_pool, self.classifier = create_classifier( self.num_features, self.num_classes, pool_type=global_pool, use_conv=True) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() def get_classifier(self): return self.classifier @@ -245,6 +246,7 @@ class DPN(nn.Module): self.num_classes = num_classes self.global_pool, self.classifier = create_classifier( self.num_features, self.num_classes, pool_type=global_pool, use_conv=True) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() def forward_features(self, x): return self.features(x) @@ -255,8 +257,7 @@ class DPN(nn.Module): if self.drop_rate > 
0.: x = F.dropout(x, p=self.drop_rate, training=self.training) x = self.classifier(x) - if not self.global_pool.is_identity(): - x = x.flatten(1) # conv classifier, flatten if pooling isn't pass-through (disabled) + x = self.flatten(x) return x diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py index 48dee6ec..a73047c5 100644 --- a/timm/models/ghostnet.py +++ b/timm/models/ghostnet.py @@ -133,7 +133,7 @@ class GhostBottleneck(nn.Module): class GhostNet(nn.Module): - def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2, in_chans=3, output_stride=32): + def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2, in_chans=3, output_stride=32, global_pool='avg'): super(GhostNet, self).__init__() # setting of inverted residual blocks assert output_stride == 32, 'only output_stride==32 is valid, dilation not supported' @@ -178,9 +178,10 @@ class GhostNet(nn.Module): # building last several layers self.num_features = out_chs = 1280 - self.global_pool = SelectAdaptivePool2d(pool_type='avg') + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True) self.act2 = nn.ReLU(inplace=True) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled self.classifier = Linear(out_chs, num_classes) def get_classifier(self): @@ -190,6 +191,7 @@ class GhostNet(nn.Module): self.num_classes = num_classes # cannot meaningfully change pooling of efficient head after creation self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled self.classifier = Linear(self.pool_dim, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): @@ -204,8 +206,7 @@ class GhostNet(nn.Module): def forward(self, x): x = self.forward_features(x) - if not self.global_pool.is_identity(): - x = x.view(x.size(0), -1) + x = self.flatten(x) if self.dropout > 0.: x = F.dropout(x, p=self.dropout, training=self.training) x = self.classifier(x) diff --git a/timm/models/helpers.py b/timm/models/helpers.py index adfef550..662a7a48 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -45,6 +45,13 @@ def load_state_dict(checkpoint_path, use_ema=False): def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True): + if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'): + # numpy checkpoint, try to load via model specific load_pretrained fn + if hasattr(model, 'load_pretrained'): + model.load_pretrained(checkpoint_path) + else: + raise NotImplementedError('Model cannot load numpy checkpoint') + return state_dict = load_state_dict(checkpoint_path, use_ema) model.load_state_dict(state_dict, strict=strict) @@ -477,3 +484,25 @@ def model_parameters(model, exclude_head=False): return [p for p in model.parameters()][:-2] else: return model.parameters() + + +def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = '.'.join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +def named_modules(module: nn.Module, name='', depth_first=True, include_root=False): + if not 
depth_first and include_root: + yield name, module + for child_name, child_module in module.named_children(): + child_name = '.'.join((name, child_name)) if name else child_name + yield from named_modules( + module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + yield name, module diff --git a/timm/models/layers/adaptive_avgmax_pool.py b/timm/models/layers/adaptive_avgmax_pool.py index d2bb9f72..ebc6ada8 100644 --- a/timm/models/layers/adaptive_avgmax_pool.py +++ b/timm/models/layers/adaptive_avgmax_pool.py @@ -55,7 +55,7 @@ class FastAdaptiveAvgPool2d(nn.Module): self.flatten = flatten def forward(self, x): - return x.mean((2, 3)) if self.flatten else x.mean((2, 3), keepdim=True) + return x.mean((2, 3), keepdim=not self.flatten) class AdaptiveAvgMaxPool2d(nn.Module): @@ -82,13 +82,13 @@ class SelectAdaptivePool2d(nn.Module): def __init__(self, output_size=1, pool_type='fast', flatten=False): super(SelectAdaptivePool2d, self).__init__() self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing - self.flatten = flatten + self.flatten = nn.Flatten(1) if flatten else nn.Identity() if pool_type == '': self.pool = nn.Identity() # pass through elif pool_type == 'fast': assert output_size == 1 - self.pool = FastAdaptiveAvgPool2d(self.flatten) - self.flatten = False + self.pool = FastAdaptiveAvgPool2d(flatten) + self.flatten = nn.Identity() elif pool_type == 'avg': self.pool = nn.AdaptiveAvgPool2d(output_size) elif pool_type == 'avgmax': @@ -101,12 +101,11 @@ class SelectAdaptivePool2d(nn.Module): assert False, 'Invalid pool type: %s' % pool_type def is_identity(self): - return self.pool_type == '' + return not self.pool_type def forward(self, x): x = self.pool(x) - if self.flatten: - x = x.flatten(1) + x = self.flatten(x) return x def feat_mult(self): diff --git a/timm/models/layers/classifier.py b/timm/models/layers/classifier.py index 516cc6c9..2b745413 100644 --- a/timm/models/layers/classifier.py +++ b/timm/models/layers/classifier.py @@ -20,7 +20,7 @@ def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False): return global_pool, num_pooled_features -def _create_fc(num_features, num_classes, pool_type='avg', use_conv=False): +def _create_fc(num_features, num_classes, use_conv=False): if num_classes <= 0: fc = nn.Identity() # pass-through (no classifier) elif use_conv: @@ -45,11 +45,12 @@ class ClassifierHead(nn.Module): self.drop_rate = drop_rate self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv) self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) - self.flatten_after_fc = use_conv and pool_type + self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity() def forward(self, x): x = self.global_pool(x) if self.drop_rate: x = F.dropout(x, p=float(self.drop_rate), training=self.training) x = self.fc(x) + x = self.flatten(x) return x diff --git a/timm/models/layers/mlp.py b/timm/models/layers/mlp.py index 4739ba74..05d07652 100644 --- a/timm/models/layers/mlp.py +++ b/timm/models/layers/mlp.py @@ -40,6 +40,12 @@ class GluMlp(nn.Module): self.fc2 = nn.Linear(hidden_features // 2, out_features) self.drop = nn.Dropout(drop) + def init_weights(self): + # override init of fc1 w/ gate portion set to weight near zero, bias=1 + fc1_mid = self.fc1.bias.shape[0] // 2 + nn.init.ones_(self.fc1.bias[fc1_mid:]) + nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6) + def forward(self, x): x = 
self.fc1(x) x, gates = x.chunk(2, dim=-1) diff --git a/timm/models/levit.py b/timm/models/levit.py index 2180254a..fa35f41f 100644 --- a/timm/models/levit.py +++ b/timm/models/levit.py @@ -84,63 +84,33 @@ __all__ = ['Levit'] @register_model -def levit_128s(pretrained=False, fuse=False,distillation=True, use_conv=False, **kwargs): +def levit_128s(pretrained=False, use_conv=False, **kwargs): return create_levit( - 'levit_128s', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + 'levit_128s', pretrained=pretrained, use_conv=use_conv, **kwargs) @register_model -def levit_128(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs): +def levit_128(pretrained=False, use_conv=False, **kwargs): return create_levit( - 'levit_128', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + 'levit_128', pretrained=pretrained, use_conv=use_conv, **kwargs) @register_model -def levit_192(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs): +def levit_192(pretrained=False, use_conv=False, **kwargs): return create_levit( - 'levit_192', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + 'levit_192', pretrained=pretrained, use_conv=use_conv, **kwargs) @register_model -def levit_256(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs): +def levit_256(pretrained=False, use_conv=False, **kwargs): return create_levit( - 'levit_256', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + 'levit_256', pretrained=pretrained, use_conv=use_conv, **kwargs) @register_model -def levit_384(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs): +def levit_384(pretrained=False, use_conv=False, **kwargs): return create_levit( - 'levit_384', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) - - -@register_model -def levit_c_128s(pretrained=False, fuse=False, distillation=True, use_conv=True,**kwargs): - return create_levit( - 'levit_128s', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) - - -@register_model -def levit_c_128(pretrained=False, fuse=False,distillation=True, use_conv=True, **kwargs): - return create_levit( - 'levit_128', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) - - -@register_model -def levit_c_192(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs): - return create_levit( - 'levit_192', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) - - -@register_model -def levit_c_256(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs): - return create_levit( - 'levit_256', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) - - -@register_model -def levit_c_384(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs): - return create_levit( - 'levit_384', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + 'levit_384', pretrained=pretrained, use_conv=use_conv, **kwargs) class ConvNorm(nn.Sequential): @@ -427,6 +397,9 @@ class AttentionSubsample(nn.Module): class Levit(nn.Module): """ Vision Transformer with support for patch or hybrid CNN input stage + + NOTE: distillation is defaulted to True since pretrained weights use it, will cause problems + w/ train scripts that don't take tuple 
outputs, """ def __init__( @@ -447,7 +420,8 @@ class Levit(nn.Module): attn_act_layer='hard_swish', distillation=True, use_conv=False, - drop_path=0): + drop_rate=0., + drop_path_rate=0.): super().__init__() act_layer = get_act_layer(act_layer) attn_act_layer = get_act_layer(attn_act_layer) @@ -486,7 +460,7 @@ class Levit(nn.Module): Attention( ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer, resolution=resolution, use_conv=use_conv), - drop_path)) + drop_path_rate)) if mr > 0: h = int(ed * mr) self.blocks.append( @@ -494,7 +468,7 @@ class Levit(nn.Module): ln_layer(ed, h, resolution=resolution), act_layer(), ln_layer(h, ed, bn_weight_init=0, resolution=resolution), - ), drop_path)) + ), drop_path_rate)) if do[0] == 'Subsample': # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) resolution_ = (resolution - 1) // do[5] + 1 @@ -511,26 +485,45 @@ class Levit(nn.Module): ln_layer(embed_dim[i + 1], h, resolution=resolution), act_layer(), ln_layer(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution), - ), drop_path)) + ), drop_path_rate)) self.blocks = nn.Sequential(*self.blocks) # Classifier head self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() + self.head_dist = None if distillation: self.head_dist = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() - else: - self.head_dist = None @torch.jit.ignore def no_weight_decay(self): return {x for x in self.state_dict().keys() if 'attention_biases' in x} - def forward(self, x): + def get_classifier(self): + if self.head_dist is None: + return self.head + else: + return self.head, self.head_dist + + def reset_classifier(self, num_classes, global_pool='', distillation=None): + self.num_classes = num_classes + self.head = NormLinear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() + if distillation is not None: + self.distillation = distillation + if self.distillation: + self.head_dist = NormLinear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() + else: + self.head_dist = None + + def forward_features(self, x): x = self.patch_embed(x) if not self.use_conv: x = x.flatten(2).transpose(1, 2) x = self.blocks(x) x = x.mean((-2, -1)) if self.use_conv else x.mean(1) + return x + + def forward(self, x): + x = self.forward_features(x) if self.head_dist is not None: x, x_dist = self.head(x), self.head_dist(x) if self.training and not torch.jit.is_scripting(): diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py index 6f53264a..ea6de824 100644 --- a/timm/models/mlp_mixer.py +++ b/timm/models/mlp_mixer.py @@ -45,7 +45,7 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .helpers import build_model_with_cfg, overlay_external_default_cfg, named_apply from .layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple from .registry import register_model @@ -169,6 +169,11 @@ class SpatialGatingUnit(nn.Module): self.norm = norm_layer(gate_dim) self.proj = nn.Linear(seq_len, seq_len) + def init_weights(self): + # special init for the projection gate, called as override by base model init + nn.init.normal_(self.proj.weight, std=1e-6) + nn.init.ones_(self.proj.bias) + def forward(self, x): u, v = x.chunk(2, dim=-1) v = self.norm(v) @@ -205,7 +210,7 @@ class MlpMixer(nn.Module): in_chans=3, patch_size=16, num_blocks=8, - hidden_dim=512, + embed_dim=512, mlp_ratio=(0.5, 4.0), 
block_layer=MixerBlock, mlp_layer=Mlp, @@ -218,59 +223,71 @@ class MlpMixer(nn.Module): ): super().__init__() self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models self.stem = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=hidden_dim, - norm_layer=norm_layer if stem_norm else None) + img_size=img_size, patch_size=patch_size, in_chans=in_chans, + embed_dim=embed_dim, norm_layer=norm_layer if stem_norm else None) # FIXME drop_path (stochastic depth scaling rule or all the same?) self.blocks = nn.Sequential(*[ block_layer( - hidden_dim, self.stem.num_patches, mlp_ratio, mlp_layer=mlp_layer, norm_layer=norm_layer, + embed_dim, self.stem.num_patches, mlp_ratio, mlp_layer=mlp_layer, norm_layer=norm_layer, act_layer=act_layer, drop=drop_rate, drop_path=drop_path_rate) for _ in range(num_blocks)]) - self.norm = norm_layer(hidden_dim) - self.head = nn.Linear(hidden_dim, self.num_classes) # zero init + self.norm = norm_layer(embed_dim) + self.head = nn.Linear(embed_dim, self.num_classes) # zero init self.init_weights(nlhb=nlhb) def init_weights(self, nlhb=False): head_bias = -math.log(self.num_classes) if nlhb else 0. - for n, m in self.named_modules(): - _init_weights(m, n, head_bias=head_bias) + named_apply(partial(_init_weights, head_bias=head_bias), module=self) # depth-first - def forward(self, x): + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): x = self.stem(x) x = self.blocks(x) x = self.norm(x) x = x.mean(dim=1) + return x + + def forward(self, x): + x = self.forward_features(x) x = self.head(x) return x -def _init_weights(m, n: str, head_bias: float = 0.): +def _init_weights(module: nn.Module, name: str, head_bias: float = 0.): """ Mixer weight initialization (trying to match Flax defaults) """ - if isinstance(m, nn.Linear): - if n.startswith('head'): - nn.init.zeros_(m.weight) - nn.init.constant_(m.bias, head_bias) - elif n.endswith('gate.proj'): - nn.init.normal_(m.weight, std=1e-4) - nn.init.ones_(m.bias) + if isinstance(module, nn.Linear): + if name.startswith('head'): + nn.init.zeros_(module.weight) + nn.init.constant_(module.bias, head_bias) else: - nn.init.xavier_uniform_(m.weight) - if m.bias is not None: - if 'mlp' in n: - nn.init.normal_(m.bias, std=1e-6) + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + if 'mlp' in name: + nn.init.normal_(module.bias, std=1e-6) else: - nn.init.zeros_(m.bias) - elif isinstance(m, nn.Conv2d): - lecun_normal_(m.weight) - if m.bias is not None: - nn.init.zeros_(m.bias) - elif isinstance(m, nn.LayerNorm): - nn.init.zeros_(m.bias) - nn.init.ones_(m.weight) + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Conv2d): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + # NOTE if a parent module contains init_weights method, it can override the init of the + # child modules as this will be called in depth-first order. 
+ module.init_weights() def _create_mixer(variant, pretrained=False, **kwargs): @@ -289,7 +306,7 @@ def mixer_s32_224(pretrained=False, **kwargs): """ Mixer-S/32 224x224 Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 """ - model_args = dict(patch_size=32, num_blocks=8, hidden_dim=512, **kwargs) + model_args = dict(patch_size=32, num_blocks=8, embed_dim=512, **kwargs) model = _create_mixer('mixer_s32_224', pretrained=pretrained, **model_args) return model @@ -299,7 +316,7 @@ def mixer_s16_224(pretrained=False, **kwargs): """ Mixer-S/16 224x224 Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 """ - model_args = dict(patch_size=16, num_blocks=8, hidden_dim=512, **kwargs) + model_args = dict(patch_size=16, num_blocks=8, embed_dim=512, **kwargs) model = _create_mixer('mixer_s16_224', pretrained=pretrained, **model_args) return model @@ -309,7 +326,7 @@ def mixer_b32_224(pretrained=False, **kwargs): """ Mixer-B/32 224x224 Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 """ - model_args = dict(patch_size=32, num_blocks=12, hidden_dim=768, **kwargs) + model_args = dict(patch_size=32, num_blocks=12, embed_dim=768, **kwargs) model = _create_mixer('mixer_b32_224', pretrained=pretrained, **model_args) return model @@ -319,7 +336,7 @@ def mixer_b16_224(pretrained=False, **kwargs): """ Mixer-B/16 224x224. ImageNet-1k pretrained weights. Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 """ - model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, **kwargs) + model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs) model = _create_mixer('mixer_b16_224', pretrained=pretrained, **model_args) return model @@ -329,7 +346,7 @@ def mixer_b16_224_in21k(pretrained=False, **kwargs): """ Mixer-B/16 224x224. ImageNet-21k pretrained weights. Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 """ - model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, **kwargs) + model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs) model = _create_mixer('mixer_b16_224_in21k', pretrained=pretrained, **model_args) return model @@ -339,7 +356,7 @@ def mixer_l32_224(pretrained=False, **kwargs): """ Mixer-L/32 224x224. Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 """ - model_args = dict(patch_size=32, num_blocks=24, hidden_dim=1024, **kwargs) + model_args = dict(patch_size=32, num_blocks=24, embed_dim=1024, **kwargs) model = _create_mixer('mixer_l32_224', pretrained=pretrained, **model_args) return model @@ -349,7 +366,7 @@ def mixer_l16_224(pretrained=False, **kwargs): """ Mixer-L/16 224x224. ImageNet-1k pretrained weights. Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 """ - model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, **kwargs) + model_args = dict(patch_size=16, num_blocks=24, embed_dim=1024, **kwargs) model = _create_mixer('mixer_l16_224', pretrained=pretrained, **model_args) return model @@ -359,35 +376,38 @@ def mixer_l16_224_in21k(pretrained=False, **kwargs): """ Mixer-L/16 224x224. ImageNet-21k pretrained weights. 
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 """ - model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, **kwargs) + model_args = dict(patch_size=16, num_blocks=24, embed_dim=1024, **kwargs) model = _create_mixer('mixer_l16_224_in21k', pretrained=pretrained, **model_args) return model + @register_model def mixer_b16_224_miil(pretrained=False, **kwargs): """ Mixer-B/16 224x224. ImageNet-21k pretrained weights. Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K """ - model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, **kwargs) + model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs) model = _create_mixer('mixer_b16_224_miil', pretrained=pretrained, **model_args) return model + @register_model def mixer_b16_224_miil_in21k(pretrained=False, **kwargs): """ Mixer-B/16 224x224. ImageNet-1k pretrained weights. Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K """ - model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, **kwargs) + model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs) model = _create_mixer('mixer_b16_224_miil_in21k', pretrained=pretrained, **model_args) return model + @register_model def gmixer_12_224(pretrained=False, **kwargs): """ Glu-Mixer-12 224x224 (short & fat) Experiment by Ross Wightman, adding (Si)GLU to MLP-Mixer """ model_args = dict( - patch_size=20, num_blocks=12, hidden_dim=512, mlp_ratio=(1.0, 6.0), + patch_size=16, num_blocks=12, embed_dim=512, mlp_ratio=(1.0, 6.0), mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs) model = _create_mixer('gmixer_12_224', pretrained=pretrained, **model_args) return model @@ -399,7 +419,7 @@ def gmixer_24_224(pretrained=False, **kwargs): Experiment by Ross Wightman, adding (Si)GLU to MLP-Mixer """ model_args = dict( - patch_size=20, num_blocks=24, hidden_dim=384, mlp_ratio=(1.0, 6.0), + patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=(1.0, 6.0), mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs) model = _create_mixer('gmixer_24_224', pretrained=pretrained, **model_args) return model @@ -411,7 +431,7 @@ def resmlp_12_224(pretrained=False, **kwargs): Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 """ model_args = dict( - patch_size=16, num_blocks=12, hidden_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs) + patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs) model = _create_mixer('resmlp_12_224', pretrained=pretrained, **model_args) return model @@ -422,7 +442,7 @@ def resmlp_24_224(pretrained=False, **kwargs): Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 """ model_args = dict( - patch_size=16, num_blocks=24, hidden_dim=384, mlp_ratio=4, + patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=4, block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs) model = _create_mixer('resmlp_24_224', pretrained=pretrained, **model_args) return model @@ -434,7 +454,7 @@ def resmlp_36_224(pretrained=False, **kwargs): Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 """ model_args = dict( - patch_size=16, num_blocks=36, hidden_dim=384, mlp_ratio=4, + patch_size=16, num_blocks=36, embed_dim=384, mlp_ratio=4, block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs) model = _create_mixer('resmlp_36_224', pretrained=pretrained, 
**model_args) return model @@ -446,7 +466,7 @@ def gmlp_ti16_224(pretrained=False, **kwargs): Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050 """ model_args = dict( - patch_size=16, num_blocks=30, hidden_dim=128, mlp_ratio=6, block_layer=SpatialGatingBlock, + patch_size=16, num_blocks=30, embed_dim=128, mlp_ratio=6, block_layer=SpatialGatingBlock, mlp_layer=GatedMlp, **kwargs) model = _create_mixer('gmlp_ti16_224', pretrained=pretrained, **model_args) return model @@ -458,7 +478,7 @@ def gmlp_s16_224(pretrained=False, **kwargs): Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050 """ model_args = dict( - patch_size=16, num_blocks=30, hidden_dim=256, mlp_ratio=6, block_layer=SpatialGatingBlock, + patch_size=16, num_blocks=30, embed_dim=256, mlp_ratio=6, block_layer=SpatialGatingBlock, mlp_layer=GatedMlp, **kwargs) model = _create_mixer('gmlp_s16_224', pretrained=pretrained, **model_args) return model @@ -470,7 +490,7 @@ def gmlp_b16_224(pretrained=False, **kwargs): Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050 """ model_args = dict( - patch_size=16, num_blocks=30, hidden_dim=512, mlp_ratio=6, block_layer=SpatialGatingBlock, + patch_size=16, num_blocks=30, embed_dim=512, mlp_ratio=6, block_layer=SpatialGatingBlock, mlp_layer=GatedMlp, **kwargs) model = _create_mixer('gmlp_b16_224', pretrained=pretrained, **model_args) return model diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index e85112e6..f810eb82 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -119,6 +119,7 @@ class MobileNetV3(nn.Module): num_pooled_chs = head_chs * self.global_pool.feat_mult() self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias) self.act2 = act_layer(inplace=True) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() efficientnet_init_weights(self) @@ -137,6 +138,7 @@ class MobileNetV3(nn.Module): self.num_classes = num_classes # cannot meaningfully change pooling of efficient head after creation self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): @@ -151,8 +153,7 @@ class MobileNetV3(nn.Module): def forward(self, x): x = self.forward_features(x) - if not self.global_pool.is_identity(): - x = x.flatten(1) + x = self.flatten(x) if self.drop_rate > 0.: x = F.dropout(x, p=self.drop_rate, training=self.training) return self.classifier(x) diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index fc0a20c2..4e0f2b21 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -111,11 +111,11 @@ default_cfgs = dict( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l1_ra2-7dce93cd.pth', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 320, 320), crop_pct=1.0), eca_nfnet_l2=_dcfg( - url='', - pool_size=(9, 9), input_size=(3, 288, 288), test_input_size=(3, 352, 352), crop_pct=1.0), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l2_ra3-da781a61.pth', + pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384), crop_pct=1.0), eca_nfnet_l3=_dcfg( url='', - pool_size=(10, 10), 
input_size=(3, 320, 320), test_input_size=(3, 384, 384), crop_pct=1.0), + pool_size=(11, 11), input_size=(3, 352, 352), test_input_size=(3, 448, 448), crop_pct=1.0), nf_regnet_b0=_dcfg( url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv'), @@ -210,6 +210,7 @@ def _dm_nfnet_cfg(depths, channels=(256, 512, 1536, 1536), act_layer='gelu', ski return cfg + model_cfgs = dict( # NFNet-F models w/ GELU compatible with DeepMind weights dm_nfnet_f0=_dm_nfnet_cfg(depths=(1, 2, 6, 3)), diff --git a/timm/models/pit.py b/timm/models/pit.py index 9c350861..460824e2 100644 --- a/timm/models/pit.py +++ b/timm/models/pit.py @@ -186,12 +186,13 @@ class PoolingVisionTransformer(nn.Module): ] self.transformers = SequentialTuple(*transformers) self.norm = nn.LayerNorm(base_dims[-1] * heads[-1], eps=1e-6) - self.embed_dim = base_dims[-1] * heads[-1] + self.num_features = self.embed_dim = base_dims[-1] * heads[-1] # Classifier head self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() - self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \ - if num_classes > 0 and distilled else nn.Identity() + self.head_dist = None + if distilled: + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() trunc_normal_(self.pos_embed, std=.02) trunc_normal_(self.cls_token, std=.02) @@ -207,13 +208,16 @@ class PoolingVisionTransformer(nn.Module): return {'pos_embed', 'cls_token'} def get_classifier(self): - return self.head + if self.head_dist is not None: + return self.head, self.head_dist + else: + return self.head def reset_classifier(self, num_classes, global_pool=''): self.num_classes = num_classes self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() - self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \ - if num_classes > 0 and self.num_tokens == 2 else nn.Identity() + if self.head_dist is not None: + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): x = self.patch_embed(x) @@ -221,19 +225,21 @@ class PoolingVisionTransformer(nn.Module): cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) x, cls_tokens = self.transformers((x, cls_tokens)) cls_tokens = self.norm(cls_tokens) - return cls_tokens + if self.head_dist is not None: + return cls_tokens[:, 0], cls_tokens[:, 1] + else: + return cls_tokens[:, 0] def forward(self, x): x = self.forward_features(x) - x_cls = self.head(x[:, 0]) - if self.num_tokens > 1: - x_dist = self.head_dist(x[:, 1]) + if self.head_dist is not None: + x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple if self.training and not torch.jit.is_scripting(): - return x_cls, x_dist + return x, x_dist else: - return (x_cls + x_dist) / 2 + return (x + x_dist) / 2 else: - return x_cls + return self.head(x) def checkpoint_filter_fn(state_dict, model): diff --git a/timm/models/registry.py b/timm/models/registry.py index 6927b6d6..f92219b2 100644 --- a/timm/models/registry.py +++ b/timm/models/registry.py @@ -65,11 +65,18 @@ def list_models(filter='', module='', pretrained=False, exclude_filters='', name model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module """ if module: - models = list(_module_to_models[module]) + all_models = list(_module_to_models[module]) else: - models = _model_entrypoints.keys() + all_models = _model_entrypoints.keys() if filter: - models = fnmatch.filter(models, filter) # 
include these models + models = [] + include_filters = filter if isinstance(filter, (tuple, list)) else [filter] + for f in include_filters: + include_models = fnmatch.filter(all_models, f) # include these models + if len(include_models): + models = set(models).union(include_models) + else: + models = all_models if exclude_filters: if not isinstance(exclude_filters, (tuple, list)): exclude_filters = [exclude_filters] diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 2f02f12a..66baa37a 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -638,12 +638,15 @@ class ResNet(nn.Module): self.num_features = 512 * block.expansion self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + self.init_weights(zero_init_last_bn=zero_init_last_bn) + + def init_weights(self, zero_init_last_bn=True): for n, m in self.named_modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') elif isinstance(m, nn.BatchNorm2d): - nn.init.constant_(m.weight, 1.) - nn.init.constant_(m.bias, 0.) + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) if zero_init_last_bn: for m in self.modules(): if hasattr(m, 'zero_init_last_bn'): diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index 250695a8..84b16bb2 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -35,9 +35,9 @@ import torch.nn as nn from functools import partial from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg +from .helpers import build_model_with_cfg, named_apply, adapt_input_conv from .registry import register_model -from .layers import GroupNormAct, ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d +from .layers import GroupNormAct, ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d def _cfg(url='', **kwargs): @@ -86,20 +86,10 @@ default_cfgs = { url='https://storage.googleapis.com/bit_models/BiT-M-R152x4.npz', num_classes=21843), - - # trained on imagenet-1k, NOTE not overly interesting set of weights, leaving disabled for now - # 'resnetv2_50x1_bits': _cfg( - # url='https://storage.googleapis.com/bit_models/BiT-S-R50x1.npz'), - # 'resnetv2_50x3_bits': _cfg( - # url='https://storage.googleapis.com/bit_models/BiT-S-R50x3.npz'), - # 'resnetv2_101x1_bits': _cfg( - # url='https://storage.googleapis.com/bit_models/BiT-S-R101x3.npz'), - # 'resnetv2_101x3_bits': _cfg( - # url='https://storage.googleapis.com/bit_models/BiT-S-R101x3.npz'), - # 'resnetv2_152x2_bits': _cfg( - # url='https://storage.googleapis.com/bit_models/BiT-S-R152x2.npz'), - # 'resnetv2_152x4_bits': _cfg( - # url='https://storage.googleapis.com/bit_models/BiT-S-R152x4.npz'), + 'resnetv2_50': _cfg( + input_size=(3, 224, 224), crop_pct=0.875, interpolation='bicubic'), + 'resnetv2_50d': _cfg( + input_size=(3, 224, 224), crop_pct=0.875, interpolation='bicubic', first_conv='stem.conv1'), } @@ -111,13 +101,6 @@ def make_div(v, divisor=8): return new_v -def tf2th(conv_weights): - """Possibly convert HWIO to OIHW.""" - if conv_weights.ndim == 4: - conv_weights = conv_weights.transpose([3, 2, 0, 1]) - return torch.from_numpy(conv_weights) - - class PreActBottleneck(nn.Module): """Pre-activation (v2) bottleneck block. 
@@ -152,6 +135,9 @@ class PreActBottleneck(nn.Module): self.conv3 = conv_layer(mid_chs, out_chs, 1) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() + def zero_init_last_bn(self): + nn.init.zeros_(self.norm3.weight) + def forward(self, x): x_preact = self.norm1(x) @@ -198,6 +184,9 @@ class Bottleneck(nn.Module): self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() self.act3 = act_layer(inplace=True) + def zero_init_last_bn(self): + nn.init.zeros_(self.norm3.weight) + def forward(self, x): # shortcut branch shortcut = x @@ -276,7 +265,7 @@ class ResNetStage(nn.Module): def create_resnetv2_stem( in_chs, out_chs=64, stem_type='', preact=True, - conv_layer=partial(StdConv2d, eps=1e-8), norm_layer=partial(GroupNormAct, num_groups=32)): + conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32)): stem = OrderedDict() assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same') @@ -285,14 +274,17 @@ def create_resnetv2_stem( # A 3 deep 3x3 conv stack as in ResNet V1D models mid_chs = out_chs // 2 stem['conv1'] = conv_layer(in_chs, mid_chs, kernel_size=3, stride=2) + stem['norm1'] = norm_layer(mid_chs) stem['conv2'] = conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1) + stem['norm2'] = norm_layer(mid_chs) stem['conv3'] = conv_layer(mid_chs, out_chs, kernel_size=3, stride=1) + if not preact: + stem['norm3'] = norm_layer(out_chs) else: # The usual 7x7 stem conv stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2) - - if not preact: - stem['norm'] = norm_layer(out_chs) + if not preact: + stem['norm'] = norm_layer(out_chs) if 'fixed' in stem_type: # 'fixed' SAME padding approximation that is used in BiT models @@ -312,11 +304,12 @@ class ResNetV2(nn.Module): """Implementation of Pre-activation (v2) ResNet mode. 
""" - def __init__(self, layers, channels=(256, 512, 1024, 2048), - num_classes=1000, in_chans=3, global_pool='avg', output_stride=32, - width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True, - act_layer=nn.ReLU, conv_layer=partial(StdConv2d, eps=1e-8), - norm_layer=partial(GroupNormAct, num_groups=32), drop_rate=0., drop_path_rate=0.): + def __init__( + self, layers, channels=(256, 512, 1024, 2048), + num_classes=1000, in_chans=3, global_pool='avg', output_stride=32, + width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True, + act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32), + drop_rate=0., drop_path_rate=0., zero_init_last_bn=True): super().__init__() self.num_classes = num_classes self.drop_rate = drop_rate @@ -354,12 +347,14 @@ class ResNetV2(nn.Module): self.head = ClassifierHead( self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True) - for n, m in self.named_modules(): - if isinstance(m, nn.Linear) or ('.fc' in n and isinstance(m, nn.Conv2d)): - nn.init.normal_(m.weight, mean=0.0, std=0.01) - nn.init.zeros_(m.bias) - elif isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + self.init_weights(zero_init_last_bn=zero_init_last_bn) + + def init_weights(self, zero_init_last_bn=True): + named_apply(partial(_init_weights, zero_init_last_bn=zero_init_last_bn), self) + + @torch.jit.ignore() + def load_pretrained(self, checkpoint_path, prefix='resnet/'): + _load_weights(self, checkpoint_path, prefix) def get_classifier(self): return self.head.fc @@ -378,41 +373,59 @@ class ResNetV2(nn.Module): def forward(self, x): x = self.forward_features(x) x = self.head(x) - if not self.head.global_pool.is_identity(): - x = x.flatten(1) # conv classifier, flatten if pooling isn't pass-through (disabled) return x - def load_pretrained(self, checkpoint_path, prefix='resnet/'): - import numpy as np - weights = np.load(checkpoint_path) - with torch.no_grad(): - stem_conv_w = tf2th(weights[f'{prefix}root_block/standardized_conv2d/kernel']) - if self.stem.conv.weight.shape[1] == 1: - self.stem.conv.weight.copy_(stem_conv_w.sum(dim=1, keepdim=True)) - # FIXME handle > 3 in_chans? 
- else: - self.stem.conv.weight.copy_(stem_conv_w) - self.norm.weight.copy_(tf2th(weights[f'{prefix}group_norm/gamma'])) - self.norm.bias.copy_(tf2th(weights[f'{prefix}group_norm/beta'])) - if self.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]: - self.head.fc.weight.copy_(tf2th(weights[f'{prefix}head/conv2d/kernel'])) - self.head.fc.bias.copy_(tf2th(weights[f'{prefix}head/conv2d/bias'])) - for i, (sname, stage) in enumerate(self.stages.named_children()): - for j, (bname, block) in enumerate(stage.blocks.named_children()): - convname = 'standardized_conv2d' - block_prefix = f'{prefix}block{i + 1}/unit{j + 1:02d}/' - block.conv1.weight.copy_(tf2th(weights[f'{block_prefix}a/{convname}/kernel'])) - block.conv2.weight.copy_(tf2th(weights[f'{block_prefix}b/{convname}/kernel'])) - block.conv3.weight.copy_(tf2th(weights[f'{block_prefix}c/{convname}/kernel'])) - block.norm1.weight.copy_(tf2th(weights[f'{block_prefix}a/group_norm/gamma'])) - block.norm2.weight.copy_(tf2th(weights[f'{block_prefix}b/group_norm/gamma'])) - block.norm3.weight.copy_(tf2th(weights[f'{block_prefix}c/group_norm/gamma'])) - block.norm1.bias.copy_(tf2th(weights[f'{block_prefix}a/group_norm/beta'])) - block.norm2.bias.copy_(tf2th(weights[f'{block_prefix}b/group_norm/beta'])) - block.norm3.bias.copy_(tf2th(weights[f'{block_prefix}c/group_norm/beta'])) - if block.downsample is not None: - w = weights[f'{block_prefix}a/proj/{convname}/kernel'] - block.downsample.conv.weight.copy_(tf2th(w)) + +def _init_weights(module: nn.Module, name: str = '', zero_init_last_bn=True): + if isinstance(module, nn.Linear) or ('head.fc' in name and isinstance(module, nn.Conv2d)): + nn.init.normal_(module.weight, mean=0.0, std=0.01) + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) + elif zero_init_last_bn and hasattr(module, 'zero_init_last_bn'): + module.zero_init_last_bn() + + +@torch.no_grad() +def _load_weights(model: nn.Module, checkpoint_path: str, prefix: str = 'resnet/'): + import numpy as np + + def t2p(conv_weights): + """Possibly convert HWIO to OIHW.""" + if conv_weights.ndim == 4: + conv_weights = conv_weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(conv_weights) + + weights = np.load(checkpoint_path) + stem_conv_w = adapt_input_conv( + model.stem.conv.weight.shape[1], t2p(weights[f'{prefix}root_block/standardized_conv2d/kernel'])) + model.stem.conv.weight.copy_(stem_conv_w) + model.norm.weight.copy_(t2p(weights[f'{prefix}group_norm/gamma'])) + model.norm.bias.copy_(t2p(weights[f'{prefix}group_norm/beta'])) + if model.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]: + model.head.fc.weight.copy_(t2p(weights[f'{prefix}head/conv2d/kernel'])) + model.head.fc.bias.copy_(t2p(weights[f'{prefix}head/conv2d/bias'])) + for i, (sname, stage) in enumerate(model.stages.named_children()): + for j, (bname, block) in enumerate(stage.blocks.named_children()): + cname = 'standardized_conv2d' + block_prefix = f'{prefix}block{i + 1}/unit{j + 1:02d}/' + block.conv1.weight.copy_(t2p(weights[f'{block_prefix}a/{cname}/kernel'])) + block.conv2.weight.copy_(t2p(weights[f'{block_prefix}b/{cname}/kernel'])) + block.conv3.weight.copy_(t2p(weights[f'{block_prefix}c/{cname}/kernel'])) + 
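+            # .npz keys follow the BiT layout, e.g. 'resnet/block1/unit01/a/standardized_conv2d/kernel'
+            # and 'resnet/block1/unit01/a/group_norm/gamma'; the a/b/c sub-prefixes map to
+            # conv1/norm1, conv2/norm2 and conv3/norm3 of each bottleneck block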
block.norm1.weight.copy_(t2p(weights[f'{block_prefix}a/group_norm/gamma'])) + block.norm2.weight.copy_(t2p(weights[f'{block_prefix}b/group_norm/gamma'])) + block.norm3.weight.copy_(t2p(weights[f'{block_prefix}c/group_norm/gamma'])) + block.norm1.bias.copy_(t2p(weights[f'{block_prefix}a/group_norm/beta'])) + block.norm2.bias.copy_(t2p(weights[f'{block_prefix}b/group_norm/beta'])) + block.norm3.bias.copy_(t2p(weights[f'{block_prefix}c/group_norm/beta'])) + if block.downsample is not None: + w = weights[f'{block_prefix}a/proj/{cname}/kernel'] + block.downsample.conv.weight.copy_(t2p(w)) def _create_resnetv2(variant, pretrained=False, **kwargs): @@ -425,130 +438,99 @@ def _create_resnetv2(variant, pretrained=False, **kwargs): **kwargs) +def _create_resnetv2_bit(variant, pretrained=False, **kwargs): + return _create_resnetv2( + variant, pretrained=pretrained, stem_type='fixed', conv_layer=partial(StdConv2d, eps=1e-8), **kwargs) + + @register_model def resnetv2_50x1_bitm(pretrained=False, **kwargs): - return _create_resnetv2( - 'resnetv2_50x1_bitm', pretrained=pretrained, - layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs) + return _create_resnetv2_bit( + 'resnetv2_50x1_bitm', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=1, **kwargs) @register_model def resnetv2_50x3_bitm(pretrained=False, **kwargs): - return _create_resnetv2( - 'resnetv2_50x3_bitm', pretrained=pretrained, - layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs) + return _create_resnetv2_bit( + 'resnetv2_50x3_bitm', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=3, **kwargs) @register_model def resnetv2_101x1_bitm(pretrained=False, **kwargs): - return _create_resnetv2( - 'resnetv2_101x1_bitm', pretrained=pretrained, - layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs) + return _create_resnetv2_bit( + 'resnetv2_101x1_bitm', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=1, **kwargs) @register_model def resnetv2_101x3_bitm(pretrained=False, **kwargs): - return _create_resnetv2( - 'resnetv2_101x3_bitm', pretrained=pretrained, - layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs) + return _create_resnetv2_bit( + 'resnetv2_101x3_bitm', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=3, **kwargs) @register_model def resnetv2_152x2_bitm(pretrained=False, **kwargs): - return _create_resnetv2( - 'resnetv2_152x2_bitm', pretrained=pretrained, - layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs) + return _create_resnetv2_bit( + 'resnetv2_152x2_bitm', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs) @register_model def resnetv2_152x4_bitm(pretrained=False, **kwargs): - return _create_resnetv2( - 'resnetv2_152x4_bitm', pretrained=pretrained, - layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs) + return _create_resnetv2_bit( + 'resnetv2_152x4_bitm', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=4, **kwargs) @register_model def resnetv2_50x1_bitm_in21k(pretrained=False, **kwargs): - return _create_resnetv2( + return _create_resnetv2_bit( 'resnetv2_50x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843), - layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs) + layers=[3, 4, 6, 3], width_factor=1, **kwargs) @register_model def resnetv2_50x3_bitm_in21k(pretrained=False, **kwargs): - return _create_resnetv2( + return _create_resnetv2_bit( 'resnetv2_50x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843), - 
layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs) + layers=[3, 4, 6, 3], width_factor=3, **kwargs) @register_model def resnetv2_101x1_bitm_in21k(pretrained=False, **kwargs): return _create_resnetv2( 'resnetv2_101x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843), - layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs) + layers=[3, 4, 23, 3], width_factor=1, **kwargs) @register_model def resnetv2_101x3_bitm_in21k(pretrained=False, **kwargs): - return _create_resnetv2( + return _create_resnetv2_bit( 'resnetv2_101x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843), - layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs) + layers=[3, 4, 23, 3], width_factor=3, **kwargs) @register_model def resnetv2_152x2_bitm_in21k(pretrained=False, **kwargs): - return _create_resnetv2( + return _create_resnetv2_bit( 'resnetv2_152x2_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843), - layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs) + layers=[3, 8, 36, 3], width_factor=2, **kwargs) @register_model def resnetv2_152x4_bitm_in21k(pretrained=False, **kwargs): - return _create_resnetv2( + return _create_resnetv2_bit( 'resnetv2_152x4_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843), - layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs) + layers=[3, 8, 36, 3], width_factor=4, **kwargs) -# NOTE the 'S' versions of the model weights arent as interesting as original 21k or transfer to 1K M. +@register_model +def resnetv2_50(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_50', pretrained=pretrained, + layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=nn.BatchNorm2d, **kwargs) -# @register_model -# def resnetv2_50x1_bits(pretrained=False, **kwargs): -# return _create_resnetv2( -# 'resnetv2_50x1_bits', pretrained=pretrained, -# layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs) -# -# -# @register_model -# def resnetv2_50x3_bits(pretrained=False, **kwargs): -# return _create_resnetv2( -# 'resnetv2_50x3_bits', pretrained=pretrained, -# layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs) -# -# -# @register_model -# def resnetv2_101x1_bits(pretrained=False, **kwargs): -# return _create_resnetv2( -# 'resnetv2_101x1_bits', pretrained=pretrained, -# layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs) -# -# -# @register_model -# def resnetv2_101x3_bits(pretrained=False, **kwargs): -# return _create_resnetv2( -# 'resnetv2_101x3_bits', pretrained=pretrained, -# layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs) -# -# -# @register_model -# def resnetv2_152x2_bits(pretrained=False, **kwargs): -# return _create_resnetv2( -# 'resnetv2_152x2_bits', pretrained=pretrained, -# layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs) -# -# -# @register_model -# def resnetv2_152x4_bits(pretrained=False, **kwargs): -# return _create_resnetv2( -# 'resnetv2_152x4_bits', pretrained=pretrained, -# layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs) -# + +@register_model +def resnetv2_50d(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_50d', pretrained=pretrained, + layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=nn.BatchNorm2d, + stem_type='deep', avg_down=True, **kwargs) diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index a845f505..2ee106d2 100644 --- 
a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -126,19 +126,18 @@ class WindowAttention(nn.Module): window_size (tuple[int]): The height and width of the window. num_heads (int): Number of attention heads. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 proj_drop (float, optional): Dropout ratio of output. Default: 0.0 """ - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.): super().__init__() self.dim = dim self.window_size = window_size # Wh, Ww self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 + self.scale = head_dim ** -0.5 # define a parameter table of relative position bias self.relative_position_bias_table = nn.Parameter( @@ -210,7 +209,6 @@ class SwinTransformerBlock(nn.Module): shift_size (int): Shift size for SW-MSA. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float, optional): Stochastic depth rate. Default: 0.0 @@ -219,7 +217,7 @@ class SwinTransformerBlock(nn.Module): """ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): super().__init__() self.dim = dim @@ -236,8 +234,8 @@ class SwinTransformerBlock(nn.Module): self.norm1 = norm_layer(dim) self.attn = WindowAttention( - dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, + attn_drop=attn_drop, proj_drop=drop) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) @@ -369,7 +367,6 @@ class BasicLayer(nn.Module): window_size (int): Local window size. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float | tuple[float], optional): Stochastic depth rate. 
Default: 0.0 @@ -379,7 +376,7 @@ class BasicLayer(nn.Module): """ def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): super().__init__() @@ -390,14 +387,11 @@ class BasicLayer(nn.Module): # build blocks self.blocks = nn.ModuleList([ - SwinTransformerBlock(dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer) + SwinTransformerBlock( + dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) for i in range(depth)]) # patch merging layer @@ -436,7 +430,6 @@ class SwinTransformer(nn.Module): window_size (int): Window size. Default: 7 mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None drop_rate (float): Dropout rate. Default: 0 attn_drop_rate (float): Attention dropout rate. Default: 0 drop_path_rate (float): Stochastic depth rate. Default: 0.1 @@ -448,7 +441,7 @@ class SwinTransformer(nn.Module): def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), - window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + window_size=7, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm, ape=False, patch_norm=True, use_checkpoint=False, weight_init='', **kwargs): @@ -491,8 +484,9 @@ class SwinTransformer(nn.Module): num_heads=num_heads[i_layer], window_size=window_size, mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, + qkv_bias=qkv_bias, + drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], norm_layer=norm_layer, downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, @@ -520,6 +514,13 @@ class SwinTransformer(nn.Module): def no_weight_decay_keywords(self): return {'relative_position_bias_table'} + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + def forward_features(self, x): x = self.patch_embed(x) if self.absolute_pos_embed is not None: diff --git a/timm/models/twins.py b/timm/models/twins.py index 793d2ede..4aed09d9 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -278,6 +278,8 @@ class Twins(nn.Module): super().__init__() self.num_classes = num_classes self.depths = depths + self.embed_dims = embed_dims + self.num_features = embed_dims[-1] img_size = to_2tuple(img_size) prev_chs = in_chans @@ -303,10 +305,10 @@ class Twins(nn.Module): self.pos_block = nn.ModuleList([PosConv(embed_dim, embed_dim) 
for embed_dim in embed_dims]) - self.norm = norm_layer(embed_dims[-1]) + self.norm = norm_layer(self.num_features) # classification head - self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() # init weights self.apply(self._init_weights) @@ -320,7 +322,7 @@ class Twins(nn.Module): def reset_classifier(self, num_classes, global_pool=''): self.num_classes = num_classes - self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() def _init_weights(self, m): if isinstance(m, nn.Linear): diff --git a/timm/models/visformer.py b/timm/models/visformer.py index 5583ea3c..16631027 100644 --- a/timm/models/visformer.py +++ b/timm/models/visformer.py @@ -13,7 +13,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg, overlay_external_default_cfg -from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d +from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d, create_classifier from .registry import register_model @@ -140,14 +140,14 @@ class Visformer(nn.Module): def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=LayerNorm2d, attn_stage='111', pos_embed=True, spatial_conv='111', - vit_stem=False, group=8, pool=True, conv_init=False, embed_norm=None): + vit_stem=False, group=8, global_pool='avg', conv_init=False, embed_norm=None): super().__init__() + img_size = to_2tuple(img_size) self.num_classes = num_classes - self.num_features = self.embed_dim = embed_dim + self.embed_dim = embed_dim self.init_channels = init_channels self.img_size = img_size self.vit_stem = vit_stem - self.pool = pool self.conv_init = conv_init if isinstance(depth, (list, tuple)): self.stage_num1, self.stage_num2, self.stage_num3 = depth @@ -164,31 +164,31 @@ class Visformer(nn.Module): self.patch_embed1 = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=embed_norm, flatten=False) - img_size //= 16 + img_size = [x // 16 for x in img_size] else: if self.init_channels is None: self.stem = None self.patch_embed1 = PatchEmbed( img_size=img_size, patch_size=patch_size // 2, in_chans=in_chans, embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False) - img_size //= 8 + img_size = [x // 8 for x in img_size] else: self.stem = nn.Sequential( nn.Conv2d(in_chans, self.init_channels, 7, stride=2, padding=3, bias=False), nn.BatchNorm2d(self.init_channels), nn.ReLU(inplace=True) ) - img_size //= 2 + img_size = [x // 2 for x in img_size] self.patch_embed1 = PatchEmbed( img_size=img_size, patch_size=patch_size // 4, in_chans=self.init_channels, embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False) - img_size //= 4 + img_size = [x // 4 for x in img_size] if self.pos_embed: if self.vit_stem: - self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim, img_size, img_size)) + self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim, *img_size)) else: - self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim//2, img_size, img_size)) + self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim//2, *img_size)) self.pos_drop = 
nn.Dropout(p=drop_rate) self.stage1 = nn.ModuleList([ Block( @@ -199,14 +199,14 @@ class Visformer(nn.Module): for i in range(self.stage_num1) ]) - #stage2 + # stage2 if not self.vit_stem: self.patch_embed2 = PatchEmbed( img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim // 2, embed_dim=embed_dim, norm_layer=embed_norm, flatten=False) - img_size //= 2 + img_size = [x // 2 for x in img_size] if self.pos_embed: - self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, img_size, img_size)) + self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, *img_size)) self.stage2 = nn.ModuleList([ Block( dim=embed_dim, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio, @@ -221,9 +221,9 @@ class Visformer(nn.Module): self.patch_embed3 = PatchEmbed( img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim, embed_dim=embed_dim * 2, norm_layer=embed_norm, flatten=False) - img_size //= 2 + img_size = [x // 2 for x in img_size] if self.pos_embed: - self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, img_size, img_size)) + self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, *img_size)) self.stage3 = nn.ModuleList([ Block( dim=embed_dim*2, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio, @@ -234,11 +234,10 @@ class Visformer(nn.Module): ]) # head - if self.pool: - self.global_pooling = nn.AdaptiveAvgPool2d(1) - head_dim = embed_dim if self.vit_stem else embed_dim * 2 - self.norm = norm_layer(head_dim) - self.head = nn.Linear(head_dim, num_classes) + self.num_features = embed_dim if self.vit_stem else embed_dim * 2 + self.norm = norm_layer(self.num_features) + self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + self.head = nn.Linear(self.num_features, num_classes) # weights init if self.pos_embed: @@ -267,7 +266,14 @@ class Visformer(nn.Module): if m.bias is not None: nn.init.constant_(m.bias, 0.) 
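The head built via create_classifier above (and the get_classifier/reset_classifier hooks in the hunk below) follow the usual timm classifier pattern; a minimal usage sketch, assuming only the visformer_small entrypoint registered in this file:

    import torch
    import timm

    m = timm.create_model('visformer_small')
    m.reset_classifier(10)                                     # rebuild pooling + nn.Linear head for 10 classes
    feats = m.forward_features(torch.randn(1, 3, 224, 224))    # unpooled (N, C, H, W) feature map after norm
    logits = m(torch.randn(1, 3, 224, 224))                    # pooled + classified -> (N, 10)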
- def forward(self, x): + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): if self.stem is not None: x = self.stem(x) @@ -297,14 +303,13 @@ class Visformer(nn.Module): for b in self.stage3: x = b(x) - # head x = self.norm(x) - if self.pool: - x = self.global_pooling(x) - else: - x = x[:, :, 0, 0] + return x - x = self.head(x.view(x.size(0), -1)) + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + x = self.head(x) return x @@ -321,7 +326,7 @@ def _create_visformer(variant, pretrained=False, default_cfg=None, **kwargs): @register_model def visformer_tiny(pretrained=False, **kwargs): model_cfg = dict( - img_size=224, init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8, + init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8, attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True, embed_norm=nn.BatchNorm2d, **kwargs) model = _create_visformer('visformer_tiny', pretrained=pretrained, **model_cfg) @@ -331,7 +336,7 @@ def visformer_tiny(pretrained=False, **kwargs): @register_model def visformer_small(pretrained=False, **kwargs): model_cfg = dict( - img_size=224, init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8, + init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8, attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True, embed_norm=nn.BatchNorm2d, **kwargs) model = _create_visformer('visformer_small', pretrained=pretrained, **model_cfg) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index ff74d836..c44358df 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -28,7 +28,7 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .helpers import build_model_with_cfg, named_apply, adapt_input_conv from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ from .registry import register_model @@ -47,9 +47,18 @@ def _cfg(url='', **kwargs): default_cfgs = { - # patch models (my experiments) + # FIXME weights coming + 'vit_tiny_patch16_224': _cfg( + url='', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + ), 'vit_small_patch16_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth', + url='', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + ), + 'vit_small_patch32_224': _cfg( + url='', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), ), # patch models (weights ported from official Google JAX impl) @@ -97,29 +106,29 @@ default_cfgs = { num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), # deit models (FB weights) - 'vit_deit_tiny_patch16_224': _cfg( + 'deit_tiny_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'), - 'vit_deit_small_patch16_224': _cfg( + 'deit_small_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'), - 'vit_deit_base_patch16_224': _cfg( + 'deit_base_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',), 
- 'vit_deit_base_patch16_384': _cfg( + 'deit_base_patch16_384': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth', input_size=(3, 384, 384), crop_pct=1.0), - 'vit_deit_tiny_distilled_patch16_224': _cfg( + 'deit_tiny_distilled_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth', classifier=('head', 'head_dist')), - 'vit_deit_small_distilled_patch16_224': _cfg( + 'deit_small_distilled_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth', classifier=('head', 'head_dist')), - 'vit_deit_base_distilled_patch16_224': _cfg( + 'deit_base_distilled_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', classifier=('head', 'head_dist')), - 'vit_deit_base_distilled_patch16_384': _cfg( + 'deit_base_distilled_patch16_384': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth', input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')), - # ViT ImageNet-21K-P pretraining + # ViT ImageNet-21K-P pretraining by MILL 'vit_base_patch16_224_miil_in21k': _cfg( url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth', mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221, @@ -133,11 +142,11 @@ default_cfgs = { class Attention(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 + self.scale = head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) @@ -161,12 +170,11 @@ class Attention(nn.Module): class Block(nn.Module): - def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): super().__init__() self.norm1 = norm_layer(dim) - self.attn = Attention( - dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() self.norm2 = norm_layer(dim) @@ -190,7 +198,7 @@ class VisionTransformer(nn.Module): """ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, - num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, distilled=False, + num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, act_layer=None, weight_init=''): """ @@ -204,7 +212,6 @@ class VisionTransformer(nn.Module): num_heads (int): number of attention heads mlp_ratio (int): ratio of mlp hidden dim to embedding dim qkv_bias (bool): enable bias for qkv if True - qk_scale (float): override default qk scale of head_dim ** -0.5 if set representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set distilled (bool): model includes a distillation token and head as in DeiT models drop_rate (float): dropout rate @@ -233,8 +240,8 @@ class VisionTransformer(nn.Module): dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule self.blocks = nn.Sequential(*[ Block( - dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) for i in range(depth)]) self.norm = norm_layer(embed_dim) @@ -254,16 +261,17 @@ class VisionTransformer(nn.Module): if distilled: self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() - # Weight init - assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '') - head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0. + self.init_weights(weight_init) + + def init_weights(self, mode=''): + assert mode in ('jax', 'jax_nlhb', 'nlhb', '') + head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. trunc_normal_(self.pos_embed, std=.02) if self.dist_token is not None: trunc_normal_(self.dist_token, std=.02) - if weight_init.startswith('jax'): + if mode.startswith('jax'): # leave cls token as zeros to match jax impl - for n, m in self.named_modules(): - _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True) + named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self) else: trunc_normal_(self.cls_token, std=.02) self.apply(_init_vit_weights) @@ -272,6 +280,10 @@ class VisionTransformer(nn.Module): # this fn left here for compat with downstream users _init_vit_weights(m) + @torch.jit.ignore() + def load_pretrained(self, checkpoint_path, prefix=''): + _load_weights(self, checkpoint_path, prefix) + @torch.jit.ignore def no_weight_decay(self): return {'pos_embed', 'cls_token', 'dist_token'} @@ -317,39 +329,92 @@ class VisionTransformer(nn.Module): return x -def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = False): +def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False): """ ViT weight initialization * When called without n, head_bias, jax_impl args it will behave exactly the same as my original init for compatibility with prev hparam / downstream use cases (ie DeiT). 
* When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl """ - if isinstance(m, nn.Linear): - if n.startswith('head'): - nn.init.zeros_(m.weight) - nn.init.constant_(m.bias, head_bias) - elif n.startswith('pre_logits'): - lecun_normal_(m.weight) - nn.init.zeros_(m.bias) + if isinstance(module, nn.Linear): + if name.startswith('head'): + nn.init.zeros_(module.weight) + nn.init.constant_(module.bias, head_bias) + elif name.startswith('pre_logits'): + lecun_normal_(module.weight) + nn.init.zeros_(module.bias) else: if jax_impl: - nn.init.xavier_uniform_(m.weight) - if m.bias is not None: - if 'mlp' in n: - nn.init.normal_(m.bias, std=1e-6) + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + if 'mlp' in name: + nn.init.normal_(module.bias, std=1e-6) else: - nn.init.zeros_(m.bias) + nn.init.zeros_(module.bias) else: - trunc_normal_(m.weight, std=.02) - if m.bias is not None: - nn.init.zeros_(m.bias) - elif jax_impl and isinstance(m, nn.Conv2d): + trunc_normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif jax_impl and isinstance(module, nn.Conv2d): # NOTE conv was left to pytorch default in my original init - lecun_normal_(m.weight) - if m.bias is not None: - nn.init.zeros_(m.bias) - elif isinstance(m, nn.LayerNorm): - nn.init.zeros_(m.bias) - nn.init.ones_(m.weight) + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)): + nn.init.zeros_(module.bias) + nn.init.ones_(module.weight) + + +@torch.no_grad() +def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): + """ Load weights from .npz checkpoints for official Google Brain Flax implementation + """ + import numpy as np + + def _n2p(w, t=True): + if t and w.ndim == 4: + w = w.transpose([3, 2, 0, 1]) + elif t and w.ndim == 3: + w = w.transpose([2, 0, 1]) + elif t and w.ndim == 2: + w = w.transpose([1, 0]) + return torch.from_numpy(w) + + w = np.load(checkpoint_path) + if not prefix: + prefix = 'opt/target/' if 'opt/target/embedding/kernel' in w else prefix + + input_conv_w = adapt_input_conv( + model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) + model.patch_embed.proj.weight.copy_(input_conv_w) + model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) + model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) + model.pos_embed.copy_(_n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)) + model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) + model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) + if model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: + model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) + model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) + for i, block in enumerate(model.blocks.children()): + block_prefix = f'{prefix}Transformer/encoderblock_{i}/' + block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) + block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) + mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' + block.attn.qkv.weight.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T, + _n2p(w[f'{mha_prefix}key/kernel'], t=False).flatten(1).T, + _n2p(w[f'{mha_prefix}value/kernel'], t=False).flatten(1).T])) + block.attn.qkv.bias.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1), + 
_n2p(w[f'{mha_prefix}key/bias'], t=False).reshape(-1), + _n2p(w[f'{mha_prefix}value/bias'], t=False).reshape(-1)])) + block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) + block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) + block.mlp.fc1.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_0/kernel'])) + block.mlp.fc1.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_0/bias'])) + block.mlp.fc2.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_1/kernel'])) + block.mlp.fc2.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_1/bias'])) + block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) + block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): @@ -417,23 +482,34 @@ def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kw return model +@register_model +def vit_tiny_patch16_224(pretrained=False, **kwargs): + """ ViT-Tiny (Vit-Ti/16) + """ + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_small_patch16_224(pretrained=False, **kwargs): - """ My custom 'small' ViT model. embed_dim=768, depth=8, num_heads=8, mlp_ratio=3. - NOTE: - * this differs from the DeiT based 'small' definitions with embed_dim=384, depth=12, num_heads=6 - * this model does not have a bias for QKV (unlike the official ViT and DeiT models) + """ ViT-Small (ViT-S/16) + NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper """ - model_kwargs = dict( - patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3., - qkv_bias=False, norm_layer=nn.LayerNorm, **kwargs) - if pretrained: - # NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model - model_kwargs.setdefault('qk_scale', 768 ** -0.5) + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs) return model +@register_model +def vit_small_patch32_224(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/32) + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_base_patch16_224(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). @@ -569,86 +645,86 @@ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs): @register_model -def vit_deit_tiny_patch16_224(pretrained=False, **kwargs): +def deit_tiny_patch16_224(pretrained=False, **kwargs): """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). ImageNet-1k weights from https://github.com/facebookresearch/deit. 
""" model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) - model = _create_vision_transformer('vit_deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) + model = _create_vision_transformer('deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) return model @register_model -def vit_deit_small_patch16_224(pretrained=False, **kwargs): +def deit_small_patch16_224(pretrained=False, **kwargs): """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). ImageNet-1k weights from https://github.com/facebookresearch/deit. """ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer('vit_deit_small_patch16_224', pretrained=pretrained, **model_kwargs) + model = _create_vision_transformer('deit_small_patch16_224', pretrained=pretrained, **model_kwargs) return model @register_model -def vit_deit_base_patch16_224(pretrained=False, **kwargs): +def deit_base_patch16_224(pretrained=False, **kwargs): """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). ImageNet-1k weights from https://github.com/facebookresearch/deit. """ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) - model = _create_vision_transformer('vit_deit_base_patch16_224', pretrained=pretrained, **model_kwargs) + model = _create_vision_transformer('deit_base_patch16_224', pretrained=pretrained, **model_kwargs) return model @register_model -def vit_deit_base_patch16_384(pretrained=False, **kwargs): +def deit_base_patch16_384(pretrained=False, **kwargs): """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877). ImageNet-1k weights from https://github.com/facebookresearch/deit. """ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) - model = _create_vision_transformer('vit_deit_base_patch16_384', pretrained=pretrained, **model_kwargs) + model = _create_vision_transformer('deit_base_patch16_384', pretrained=pretrained, **model_kwargs) return model @register_model -def vit_deit_tiny_distilled_patch16_224(pretrained=False, **kwargs): +def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs): """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). ImageNet-1k weights from https://github.com/facebookresearch/deit. """ model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) model = _create_vision_transformer( - 'vit_deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) + 'deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) return model @register_model -def vit_deit_small_distilled_patch16_224(pretrained=False, **kwargs): +def deit_small_distilled_patch16_224(pretrained=False, **kwargs): """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). ImageNet-1k weights from https://github.com/facebookresearch/deit. 
""" model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) model = _create_vision_transformer( - 'vit_deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) + 'deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) return model @register_model -def vit_deit_base_distilled_patch16_224(pretrained=False, **kwargs): +def deit_base_distilled_patch16_224(pretrained=False, **kwargs): """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). ImageNet-1k weights from https://github.com/facebookresearch/deit. """ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer( - 'vit_deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) + 'deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) return model @register_model -def vit_deit_base_distilled_patch16_384(pretrained=False, **kwargs): +def deit_base_distilled_patch16_384(pretrained=False, **kwargs): """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877). ImageNet-1k weights from https://github.com/facebookresearch/deit. """ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer( - 'vit_deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs) + 'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs) return model diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index 7fc0cc88..c807ee9a 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -46,8 +46,8 @@ default_cfgs = { input_size=(3, 384, 384), crop_pct=1.0), # hybrid in-1k models (mostly untrained, experimental configs w/ resnetv2 stdconv backbones) - 'vit_tiny_r_s16_p8_224': _cfg(), - 'vit_small_r_s16_p8_224': _cfg(), + 'vit_tiny_r_s16_p8_224': _cfg(first_conv='patch_embed.backbone.conv'), + 'vit_small_r_s16_p8_224': _cfg(first_conv='patch_embed.backbone.conv'), 'vit_small_r20_s16_p2_224': _cfg(), 'vit_small_r20_s16_224': _cfg(), 'vit_small_r26_s32_224': _cfg(), @@ -57,10 +57,14 @@ default_cfgs = { 'vit_large_r50_s32_224': _cfg(), # hybrid models (using timm resnet backbones) - 'vit_small_resnet26d_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), - 'vit_small_resnet50d_s16_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), - 'vit_base_resnet26d_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), - 'vit_base_resnet50d_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + 'vit_small_resnet26d_224': _cfg( + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), + 'vit_small_resnet50d_s16_224': _cfg( + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), + 'vit_base_resnet26d_224': _cfg( + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), + 'vit_base_resnet50d_224': _cfg( + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), } @@ -140,12 +144,6 @@ def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs): return model -@register_model -def vit_base_resnet50_224_in21k(pretrained=False, **kwargs): - # NOTE this is forwarding to model def above for backwards 
compatibility - return vit_base_r50_s16_224_in21k(pretrained=pretrained, **kwargs) - - @register_model def vit_base_r50_s16_384(pretrained=False, **kwargs): """ R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929). @@ -158,12 +156,6 @@ def vit_base_r50_s16_384(pretrained=False, **kwargs): return model -@register_model -def vit_base_resnet50_384(pretrained=False, **kwargs): - # NOTE this is forwarding to model def above for backwards compatibility - return vit_base_r50_s16_384(pretrained=pretrained, **kwargs) - - @register_model def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs): """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224. From 0020268d9b292a3b8ac82dcb1e21039ca32b0823 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 12 Jun 2021 23:31:24 -0700 Subject: [PATCH 05/16] Try lower max size for non_std default_cfg test --- tests/test_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index ac156806..52a8023a 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -174,7 +174,7 @@ def test_model_default_cfgs_non_std(model_name, batch_size): cfg = model.default_cfg input_size = _get_input_size(model_name=model_name, target=TARGET_FWD_SIZE) - if max(input_size) > MAX_FWD_SIZE: + if max(input_size) > 320: # FIXME const pytest.skip("Fixed input size model > limit.") input_tensor = torch.randn((batch_size, *input_size)) From 8319e0c37357e162f0e870c08621f145a8e76830 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 13 Jun 2021 12:31:06 -0700 Subject: [PATCH 06/16] Add file docstring to std_conv.py --- timm/models/layers/std_conv.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/timm/models/layers/std_conv.py b/timm/models/layers/std_conv.py index 49b35875..3ccc16e1 100644 --- a/timm/models/layers/std_conv.py +++ b/timm/models/layers/std_conv.py @@ -1,3 +1,21 @@ +""" Convolution with Weight Standardization (StdConv and ScaledStdConv) + +StdConv: +@article{weightstandardization, + author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Yuille}, + title = {Weight Standardization}, + journal = {arXiv preprint arXiv:1903.10520}, + year = {2019}, +} +Code: https://github.com/joe-siyuan-qiao/WeightStandardization + +ScaledStdConv: +Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 +Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/master/nfnets + +Hacked together by / copyright Ross Wightman, 2021. +""" import torch import torch.nn as nn import torch.nn.functional as F @@ -5,12 +23,6 @@ import torch.nn.functional as F from .padding import get_padding, get_padding_value, pad_same -def get_weight(module): - std, mean = torch.std_mean(module.weight, dim=[1, 2, 3], keepdim=True, unbiased=False) - weight = (module.weight - mean) / (std + module.eps) - return weight - - class StdConv2d(nn.Conv2d): """Conv2d with Weight Standardization. Used for BiT ResNet-V2 models. 
@@ -30,7 +42,7 @@ class StdConv2d(nn.Conv2d): def forward(self, x): weight = F.batch_norm( self.weight.view(1, self.out_channels, -1), None, None, - eps=self.eps, training=True, momentum=0.).reshape_as(self.weight) + training=True, momentum=0., eps=self.eps).reshape_as(self.weight) x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) return x @@ -56,7 +68,7 @@ class StdConv2dSame(nn.Conv2d): x = pad_same(x, self.kernel_size, self.stride, self.dilation) weight = F.batch_norm( self.weight.view(1, self.out_channels, -1), None, None, - eps=self.eps, training=True, momentum=0.).reshape_as(self.weight) + training=True, momentum=0., eps=self.eps).reshape_as(self.weight) x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) return x @@ -86,7 +98,7 @@ class ScaledStdConv2d(nn.Conv2d): weight = F.batch_norm( self.weight.view(1, self.out_channels, -1), None, None, weight=(self.gain * self.scale).view(-1), - eps=self.eps, training=True, momentum=0.).reshape_as(self.weight) + training=True, momentum=0., eps=self.eps).reshape_as(self.weight) return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) @@ -117,5 +129,5 @@ class ScaledStdConv2dSame(nn.Conv2d): weight = F.batch_norm( self.weight.view(1, self.out_channels, -1), None, None, weight=(self.gain * self.scale).view(-1), - eps=self.eps, training=True, momentum=0.).reshape_as(self.weight) + training=True, momentum=0., eps=self.eps).reshape_as(self.weight) return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) From b9cfb64412e367a1352d46f00906453d0274282c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 14 Jun 2021 12:31:44 -0700 Subject: [PATCH 07/16] Support npz custom load for vision transformer hybrid models. Add posembed rescale for npz load. 
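A minimal usage sketch for the npz load path added in this patch (not part of the diff
below; the checkpoint path is a placeholder and assumes a locally downloaded Flax .npz
file):

    import timm
    from timm.models.vision_transformer import _load_weights

    # Build a hybrid ViT without pretrained weights, then load a JAX/Flax .npz directly.
    # For hybrid models, _load_weights also copies the ResNet stem (conv_root/gn_root) and
    # block{i}/unit{j} tensors into patch_embed.backbone, and resizes the position
    # embedding via resize_pos_embed() when the model grid size differs from the checkpoint.
    model = timm.create_model('vit_base_r50_s16_224_in21k', pretrained=False)
    _load_weights(model, '/path/to/checkpoint.npz')  # placeholder path

With this change, models whose default_cfg url ends in '.npz' go through the same custom
load (via pretrained_custom_load) when created with pretrained=True.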
--- timm/models/layers/pool2d_same.py | 10 +- timm/models/vision_transformer.py | 96 ++++++++++++----- timm/models/vision_transformer_hybrid.py | 131 ++++++++++++++++++----- 3 files changed, 181 insertions(+), 56 deletions(-) diff --git a/timm/models/layers/pool2d_same.py b/timm/models/layers/pool2d_same.py index 5fcd0f1f..4c2a1c44 100644 --- a/timm/models/layers/pool2d_same.py +++ b/timm/models/layers/pool2d_same.py @@ -27,7 +27,8 @@ class AvgPool2dSame(nn.AvgPool2d): super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad) def forward(self, x): - return avg_pool2d_same( + x = pad_same(x, self.kernel_size, self.stride) + return F.avg_pool2d( x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) @@ -41,14 +42,15 @@ def max_pool2d_same( class MaxPool2dSame(nn.MaxPool2d): """ Tensorflow like 'SAME' wrapper for 2D max pooling """ - def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False, count_include_pad=True): + def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False): kernel_size = to_2tuple(kernel_size) stride = to_2tuple(stride) dilation = to_2tuple(dilation) - super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode, count_include_pad) + super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode) def forward(self, x): - return max_pool2d_same(x, self.kernel_size, self.stride, self.padding, self.dilation, self.ceil_mode) + x = pad_same(x, self.kernel_size, self.stride, value=-float('inf')) + return F.max_pool2d(x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode) def create_pool2d(pool_type, kernel_size, stride=None, **kwargs): diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index c44358df..7dd9137e 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -52,6 +52,10 @@ default_cfgs = { url='', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), ), + 'vit_tiny_patch16_384': _cfg( + url='', + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0 + ), 'vit_small_patch16_224': _cfg( url='', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), @@ -60,6 +64,14 @@ default_cfgs = { url='', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), ), + 'vit_small_patch16_384': _cfg( + url='', + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0 + ), + 'vit_small_patch32_384': _cfg( + url='', + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0 + ), # patch models (weights ported from official Google JAX impl) 'vit_base_patch16_224': _cfg( @@ -102,6 +114,7 @@ default_cfgs = { url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 'vit_huge_patch14_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz', hf_hub='timm/vit_huge_patch14_224_in21k', num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), @@ -371,24 +384,53 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = import numpy as np def _n2p(w, t=True): - if t and w.ndim == 4: - w = w.transpose([3, 2, 0, 1]) - elif t and w.ndim == 3: - w = w.transpose([2, 0, 1]) - elif t and w.ndim == 2: - w = w.transpose([1, 0]) + if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: + w = w.flatten() + if 
t: + if w.ndim == 4: + w = w.transpose([3, 2, 0, 1]) + elif w.ndim == 3: + w = w.transpose([2, 0, 1]) + elif w.ndim == 2: + w = w.transpose([1, 0]) return torch.from_numpy(w) w = np.load(checkpoint_path) - if not prefix: - prefix = 'opt/target/' if 'opt/target/embedding/kernel' in w else prefix - - input_conv_w = adapt_input_conv( - model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) - model.patch_embed.proj.weight.copy_(input_conv_w) + if not prefix and 'opt/target/embedding/kernel' in w: + prefix = 'opt/target/' + + if hasattr(model.patch_embed, 'backbone'): + # hybrid + backbone = model.patch_embed.backbone + stem_only = not hasattr(backbone, 'stem') + stem = backbone if stem_only else backbone.stem + stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) + stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) + stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) + if not stem_only: + for i, stage in enumerate(backbone.stages): + for j, block in enumerate(stage.blocks): + bp = f'{prefix}block{i + 1}/unit{j + 1}/' + for r in range(3): + getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) + getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) + getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) + if block.downsample is not None: + block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) + block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) + block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) + embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) + else: + embed_conv_w = adapt_input_conv( + model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) + model.patch_embed.proj.weight.copy_(embed_conv_w) model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) - model.pos_embed.copy_(_n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)) + pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) + if pos_embed_w.shape != model.pos_embed.shape: + pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights + pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + model.pos_embed.copy_(pos_embed_w) model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) if model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: @@ -396,23 +438,18 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) for i, block in enumerate(model.blocks.children()): block_prefix = f'{prefix}Transformer/encoderblock_{i}/' + mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) - mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' block.attn.qkv.weight.copy_(torch.cat([ - _n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T, - _n2p(w[f'{mha_prefix}key/kernel'], t=False).flatten(1).T, - _n2p(w[f'{mha_prefix}value/kernel'], t=False).flatten(1).T])) + _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) block.attn.qkv.bias.copy_(torch.cat([ - 
_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1), - _n2p(w[f'{mha_prefix}key/bias'], t=False).reshape(-1), - _n2p(w[f'{mha_prefix}value/bias'], t=False).reshape(-1)])) + _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) - block.mlp.fc1.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_0/kernel'])) - block.mlp.fc1.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_0/bias'])) - block.mlp.fc2.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_1/kernel'])) - block.mlp.fc2.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_1/bias'])) + for r in range(2): + getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) + getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) @@ -478,6 +515,7 @@ def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kw default_cfg=default_cfg, representation_size=repr_size, pretrained_filter_fn=checkpoint_filter_fn, + pretrained_custom_load='npz' in default_cfg['url'], **kwargs) return model @@ -510,6 +548,16 @@ def vit_small_patch32_224(pretrained=False, **kwargs): return model +@register_model +def vit_small_patch16_384(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/16) + NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_base_patch16_224(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 
diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index c807ee9a..1bfe6685 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -35,26 +35,34 @@ def _cfg(url='', **kwargs): default_cfgs = { - # hybrid in-21k models (weights ported from official Google JAX impl where they exist) - 'vit_base_r50_s16_224_in21k': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth', - num_classes=21843, crop_pct=0.9), - - # hybrid in-1k models (weights ported from official JAX impl) - 'vit_base_r50_s16_384': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth', - input_size=(3, 384, 384), crop_pct=1.0), - - # hybrid in-1k models (mostly untrained, experimental configs w/ resnetv2 stdconv backbones) + # hybrid in-1k models (weights ported from official JAX impl where they exist) 'vit_tiny_r_s16_p8_224': _cfg(first_conv='patch_embed.backbone.conv'), + 'vit_tiny_r_s16_p8_384': _cfg( + first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0), 'vit_small_r_s16_p8_224': _cfg(first_conv='patch_embed.backbone.conv'), 'vit_small_r20_s16_p2_224': _cfg(), 'vit_small_r20_s16_224': _cfg(), 'vit_small_r26_s32_224': _cfg(), + 'vit_small_r26_s32_384': _cfg( + input_size=(3, 384, 384), crop_pct=1.0), 'vit_base_r20_s16_224': _cfg(), 'vit_base_r26_s32_224': _cfg(), 'vit_base_r50_s16_224': _cfg(), + 'vit_base_r50_s16_384': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth', + input_size=(3, 384, 384), crop_pct=1.0), 'vit_large_r50_s32_224': _cfg(), + 'vit_large_r50_s32_384': _cfg(), + + # hybrid in-21k models (weights ported from official Google JAX impl where they exist) + 'vit_small_r26_s32_224_in21k': _cfg( + num_classes=21843, crop_pct=0.9), + 'vit_small_r26_s32_384_in21k': _cfg( + num_classes=21843, input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_r50_s16_224_in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth', + num_classes=21843, crop_pct=0.9), + 'vit_large_r50_s32_224_in21k': _cfg(num_classes=21843, crop_pct=0.9), # hybrid models (using timm resnet backbones) 'vit_small_resnet26d_224': _cfg( @@ -99,7 +107,8 @@ class HybridEmbed(nn.Module): else: feature_dim = self.backbone.num_features assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0 - self.num_patches = feature_size[0] // patch_size[0] * feature_size[1] // patch_size[1] + self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self, x): @@ -133,37 +142,35 @@ def _resnetv2(layers=(3, 4, 9), **kwargs): @register_model -def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs): - """ R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929). - ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. +def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs): + """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224. 
""" - backbone = _resnetv2(layers=(3, 4, 9), **kwargs) - model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs) + backbone = _resnetv2(layers=(), **kwargs) + model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs) model = _create_vision_transformer_hybrid( - 'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) + 'vit_tiny_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @register_model -def vit_base_r50_s16_384(pretrained=False, **kwargs): - """ R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929). - ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. +def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs): + """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 384 x 384. """ - backbone = _resnetv2((3, 4, 9), **kwargs) - model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) + backbone = _resnetv2(layers=(), **kwargs) + model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs) model = _create_vision_transformer_hybrid( - 'vit_base_r50_s16_384', backbone=backbone, pretrained=pretrained, **model_kwargs) + 'vit_tiny_r_s16_p8_384', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @register_model -def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs): - """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224. +def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs): + """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 384 x 384. """ backbone = _resnetv2(layers=(), **kwargs) model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs) model = _create_vision_transformer_hybrid( - 'vit_tiny_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **model_kwargs) + 'vit_tiny_r_s16_p8_384', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @@ -212,6 +219,17 @@ def vit_small_r26_s32_224(pretrained=False, **kwargs): return model +@register_model +def vit_small_r26_s32_384(pretrained=False, **kwargs): + """ R26+ViT-S/S32 hybrid. + """ + backbone = _resnetv2((2, 2, 2, 2), **kwargs) + model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_small_r26_s32_384', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_base_r20_s16_224(pretrained=False, **kwargs): """ R20+ViT-B/S16 hybrid. @@ -245,17 +263,74 @@ def vit_base_r50_s16_224(pretrained=False, **kwargs): return model +@register_model +def vit_base_r50_s16_384(pretrained=False, **kwargs): + """ R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + backbone = _resnetv2((3, 4, 9), **kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_base_r50_s16_384', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_large_r50_s32_224(pretrained=False, **kwargs): """ R50+ViT-L/S32 hybrid. 
""" backbone = _resnetv2((3, 4, 6, 3), **kwargs) - model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) + model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs) model = _create_vision_transformer_hybrid( 'vit_large_r50_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs) return model +@register_model +def vit_large_r50_s32_384(pretrained=False, **kwargs): + """ R50+ViT-L/S32 hybrid. + """ + backbone = _resnetv2((3, 4, 6, 3), **kwargs) + model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_large_r50_s32_384', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_r26_s32_224_in21k(pretrained=False, **kwargs): + """ R26+ViT-S/S32 hybrid. + """ + backbone = _resnetv2((2, 2, 2, 2), **kwargs) + model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_small_r26_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_r26_s32_384_in21k(pretrained=False, **kwargs): + """ R26+ViT-S/S32 hybrid. + """ + backbone = _resnetv2((2, 2, 2, 2), **kwargs) + model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_small_r26_s32_384_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs): + """ R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + """ + backbone = _resnetv2(layers=(3, 4, 9), **kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_small_resnet26d_224(pretrained=False, **kwargs): """ Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights. From 511a8e8c96dcbac7014aec8355f38a658ef40e49 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 14 Jun 2021 17:01:12 -0700 Subject: [PATCH 08/16] Add official ResMLP weights. --- timm/models/mlp_mixer.py | 146 +++++++++++++++++++++++++++++++++++---- 1 file changed, 134 insertions(+), 12 deletions(-) diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py index 637e00ea..db3a1be5 100644 --- a/timm/models/mlp_mixer.py +++ b/timm/models/mlp_mixer.py @@ -14,8 +14,9 @@ Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2 year={2021} } -Also supporting preliminary (not verified) implementations of ResMlp, gMLP, and possibly more... 
+Also supporting ResMlp, and a preliminary (not verified) implementations of gMLP +Code: https://github.com/facebookresearch/deit Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 @misc{touvron2021resmlp, title={ResMLP: Feedforward networks for image classification with data-efficient training}, @@ -94,11 +95,36 @@ default_cfgs = dict( gmixer_12_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), gmixer_24_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), - resmlp_12_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + resmlp_12_224=_cfg( + url='https://dl.fbaipublicfiles.com/deit/resmlp_12_no_dist.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), resmlp_24_224=_cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resmlp_24_224_raa-a8256759.pth', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=0.89), - resmlp_36_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + url='https://dl.fbaipublicfiles.com/deit/resmlp_24_no_dist.pth', + #url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resmlp_24_224_raa-a8256759.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + resmlp_36_224=_cfg( + url='https://dl.fbaipublicfiles.com/deit/resmlp_36_no_dist.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + resmlp_big_24_224=_cfg( + url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_no_dist.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + + resmlp_12_distilled_224=_cfg( + url='https://dl.fbaipublicfiles.com/deit/resmlp_12_dist.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + resmlp_24_distilled_224=_cfg( + url='https://dl.fbaipublicfiles.com/deit/resmlp_24_dist.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + resmlp_36_distilled_224=_cfg( + url='https://dl.fbaipublicfiles.com/deit/resmlp_36_dist.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + resmlp_big_24_distilled_224=_cfg( + url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_dist.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + + resmlp_big_24_224_in22ft1k=_cfg( + url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_22k.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), gmlp_ti16_224=_cfg(), gmlp_s16_224=_cfg(), @@ -266,7 +292,7 @@ class MlpMixer(nn.Module): return x -def _init_weights(module: nn.Module, name: str, head_bias: float = 0.): +def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=False): """ Mixer weight initialization (trying to match Flax defaults) """ if isinstance(module, nn.Linear): @@ -274,12 +300,19 @@ def _init_weights(module: nn.Module, name: str, head_bias: float = 0.): nn.init.zeros_(module.weight) nn.init.constant_(module.bias, head_bias) else: - nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - if 'mlp' in name: - nn.init.normal_(module.bias, std=1e-6) - else: + if flax: + # Flax defaults + lecun_normal_(module.weight) + if module.bias is not None: nn.init.zeros_(module.bias) + else: + # like MLP init in vit (my original init) + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + if 'mlp' in name: + nn.init.normal_(module.bias, std=1e-6) + else: + nn.init.zeros_(module.bias) elif isinstance(module, nn.Conv2d): lecun_normal_(module.weight) if module.bias is not None: @@ -293,6 +326,23 @@ def _init_weights(module: nn.Module, name: str, head_bias: 
float = 0.): module.init_weights() +def checkpoint_filter_fn(state_dict, model): + """ Remap checkpoints if needed """ + if 'patch_embed.proj.weight' in state_dict: + # Remap FB ResMlp models -> timm + out_dict = {} + for k, v in state_dict.items(): + k = k.replace('patch_embed.', 'stem.') + k = k.replace('attn.', 'linear_tokens.') + k = k.replace('mlp.', 'mlp_channels.') + k = k.replace('gamma_', 'ls') + if k.endswith('.alpha') or k.endswith('.beta'): + v = v.reshape(1, 1, -1) + out_dict[k] = v + return out_dict + return state_dict + + def _create_mixer(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for MLP-Mixer models.') @@ -300,6 +350,7 @@ def _create_mixer(variant, pretrained=False, **kwargs): model = build_model_with_cfg( MlpMixer, variant, pretrained, default_cfg=default_cfgs[variant], + pretrained_filter_fn=checkpoint_filter_fn, **kwargs) return model @@ -458,11 +509,82 @@ def resmlp_36_224(pretrained=False, **kwargs): """ model_args = dict( patch_size=16, num_blocks=36, embed_dim=384, mlp_ratio=4, - block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs) + block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs) model = _create_mixer('resmlp_36_224', pretrained=pretrained, **model_args) return model +@register_model +def resmlp_big_24_224(pretrained=False, **kwargs): + """ ResMLP-B-24 + Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 + """ + model_args = dict( + patch_size=8, num_blocks=24, embed_dim=768, mlp_ratio=4, + block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs) + model = _create_mixer('resmlp_big_24_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def resmlp_12_distilled_224(pretrained=False, **kwargs): + """ ResMLP-12 + Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 + """ + model_args = dict( + patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs) + model = _create_mixer('resmlp_12_distilled_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def resmlp_24_distilled_224(pretrained=False, **kwargs): + """ ResMLP-24 + Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 + """ + model_args = dict( + patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=4, + block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs) + model = _create_mixer('resmlp_24_distilled_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def resmlp_36_distilled_224(pretrained=False, **kwargs): + """ ResMLP-36 + Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 + """ + model_args = dict( + patch_size=16, num_blocks=36, embed_dim=384, mlp_ratio=4, + block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs) + model = _create_mixer('resmlp_36_distilled_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def resmlp_big_24_distilled_224(pretrained=False, **kwargs): + """ ResMLP-B-24 + Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 + """ + model_args = dict( + patch_size=8, num_blocks=24, embed_dim=768, mlp_ratio=4, + block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs) + 
model = _create_mixer('resmlp_big_24_distilled_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def resmlp_big_24_224_in22ft1k(pretrained=False, **kwargs): + """ ResMLP-B-24 + Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 + """ + model_args = dict( + patch_size=8, num_blocks=24, embed_dim=768, mlp_ratio=4, + block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs) + model = _create_mixer('resmlp_big_24_224_in22ft1k', pretrained=pretrained, **model_args) + return model + + @register_model def gmlp_ti16_224(pretrained=False, **kwargs): """ gMLP-Tiny From 1228f5a3d84afb7da614387be681b08a3dc8317f Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 18 Jun 2021 11:40:33 -0700 Subject: [PATCH 09/16] Add BiT distilled 50x1 and teacher 152x2 models from 'A good teacher is patient and consistent' paper. --- timm/models/resnetv2.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index 84b16bb2..054b0af1 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -11,6 +11,7 @@ https://github.com/google-research/vision_transformer Thanks to the Google team for the above two repositories and associated papers: * Big Transfer (BiT): General Visual Representation Learning - https://arxiv.org/abs/1912.11370 * An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale - https://arxiv.org/abs/2010.11929 +* Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237 Original copyright of Google code below, modifications by Ross Wightman, Copyright 2020. """ @@ -86,6 +87,16 @@ default_cfgs = { url='https://storage.googleapis.com/bit_models/BiT-M-R152x4.npz', num_classes=21843), + 'resnetv2_50x1_bit_distilled': _cfg( + url='https://storage.googleapis.com/bit_models/distill/R50x1_224.npz', + input_size=(3, 224, 224), crop_pct=0.875, interpolation='bicubic'), + 'resnetv2_152x2_bit_teacher': _cfg( + url='https://storage.googleapis.com/bit_models/distill/R152x2_T_224.npz', + input_size=(3, 224, 224), crop_pct=0.875, interpolation='bicubic'), + 'resnetv2_152x2_bit_teacher_384': _cfg( + url='https://storage.googleapis.com/bit_models/distill/R152x2_T_384.npz', + input_size=(3, 384, 384), crop_pct=1.0, interpolation='bicubic'), + 'resnetv2_50': _cfg( input_size=(3, 224, 224), crop_pct=0.875, interpolation='bicubic'), 'resnetv2_50d': _cfg( @@ -521,6 +532,33 @@ def resnetv2_152x4_bitm_in21k(pretrained=False, **kwargs): layers=[3, 8, 36, 3], width_factor=4, **kwargs) +@register_model +def resnetv2_50x1_bit_distilled(pretrained=False, **kwargs): + """ ResNetV2-50x1-BiT Distilled + Paper: Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237 + """ + return _create_resnetv2_bit( + 'resnetv2_50x1_bit_distilled', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=1, **kwargs) + + +@register_model +def resnetv2_152x2_bit_teacher(pretrained=False, **kwargs): + """ ResNetV2-152x2-BiT Teacher + Paper: Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237 + """ + return _create_resnetv2_bit( + 'resnetv2_152x2_bit_teacher', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs) + + +@register_model +def resnetv2_152x2_bit_teacher_384(pretrained=False, **kwargs): + """ ResNetV2-152xx-BiT Teacher @ 384x384 + Paper: Knowledge distillation: A good teacher is 
patient and consistent - https://arxiv.org/abs/2106.05237 + """ + return _create_resnetv2_bit( + 'resnetv2_152x2_bit_teacher_384', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs) + + @register_model def resnetv2_50(pretrained=False, **kwargs): return _create_resnetv2( From 8257b86550b8453b658e386498d4e643d6bf8d38 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 18 Jun 2021 16:16:06 -0700 Subject: [PATCH 10/16] Fix up resnetv2 bit/bitm model default res --- timm/models/resnetv2.py | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index 054b0af1..a3c89532 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -44,8 +44,8 @@ from .layers import GroupNormAct, ClassifierHead, DropPath, AvgPool2dSame, creat def _cfg(url='', **kwargs): return { 'url': url, - 'num_classes': 1000, 'input_size': (3, 480, 480), 'pool_size': (7, 7), - 'crop_pct': 1.0, 'interpolation': 'bilinear', + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, 'first_conv': 'stem.conv', 'classifier': 'head.fc', **kwargs @@ -55,17 +55,23 @@ def _cfg(url='', **kwargs): default_cfgs = { # pretrained on imagenet21k, finetuned on imagenet1k 'resnetv2_50x1_bitm': _cfg( - url='https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz'), + url='https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz', + input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0), 'resnetv2_50x3_bitm': _cfg( - url='https://storage.googleapis.com/bit_models/BiT-M-R50x3-ILSVRC2012.npz'), + url='https://storage.googleapis.com/bit_models/BiT-M-R50x3-ILSVRC2012.npz', + input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0), 'resnetv2_101x1_bitm': _cfg( - url='https://storage.googleapis.com/bit_models/BiT-M-R101x1-ILSVRC2012.npz'), + url='https://storage.googleapis.com/bit_models/BiT-M-R101x1-ILSVRC2012.npz', + input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0), 'resnetv2_101x3_bitm': _cfg( - url='https://storage.googleapis.com/bit_models/BiT-M-R101x3-ILSVRC2012.npz'), + url='https://storage.googleapis.com/bit_models/BiT-M-R101x3-ILSVRC2012.npz', + input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0), 'resnetv2_152x2_bitm': _cfg( - url='https://storage.googleapis.com/bit_models/BiT-M-R152x2-ILSVRC2012.npz'), + url='https://storage.googleapis.com/bit_models/BiT-M-R152x2-ILSVRC2012.npz', + input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0), 'resnetv2_152x4_bitm': _cfg( - url='https://storage.googleapis.com/bit_models/BiT-M-R152x4-ILSVRC2012.npz'), + url='https://storage.googleapis.com/bit_models/BiT-M-R152x4-ILSVRC2012.npz', + input_size=(3, 480, 480), pool_size=(15, 15), crop_pct=1.0), # only one at 480x480? 
# trained on imagenet-21k 'resnetv2_50x1_bitm_in21k': _cfg( @@ -89,18 +95,18 @@ default_cfgs = { 'resnetv2_50x1_bit_distilled': _cfg( url='https://storage.googleapis.com/bit_models/distill/R50x1_224.npz', - input_size=(3, 224, 224), crop_pct=0.875, interpolation='bicubic'), + interpolation='bicubic'), 'resnetv2_152x2_bit_teacher': _cfg( url='https://storage.googleapis.com/bit_models/distill/R152x2_T_224.npz', - input_size=(3, 224, 224), crop_pct=0.875, interpolation='bicubic'), + interpolation='bicubic'), 'resnetv2_152x2_bit_teacher_384': _cfg( url='https://storage.googleapis.com/bit_models/distill/R152x2_T_384.npz', - input_size=(3, 384, 384), crop_pct=1.0, interpolation='bicubic'), + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, interpolation='bicubic'), 'resnetv2_50': _cfg( - input_size=(3, 224, 224), crop_pct=0.875, interpolation='bicubic'), + interpolation='bicubic'), 'resnetv2_50d': _cfg( - input_size=(3, 224, 224), crop_pct=0.875, interpolation='bicubic', first_conv='stem.conv1'), + interpolation='bicubic', first_conv='stem.conv1'), } From b319eb5b5d8d29d109a1ca33bd0de0a1ac0d329c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 18 Jun 2021 16:16:49 -0700 Subject: [PATCH 11/16] Update ViT weights, more details to be added before merge. --- timm/models/vision_transformer.py | 264 ++++++++++++++--------- timm/models/vision_transformer_hybrid.py | 128 +++++------ 2 files changed, 211 insertions(+), 181 deletions(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 7dd9137e..b8fc6fa5 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -27,7 +27,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from .helpers import build_model_with_cfg, named_apply, adapt_input_conv from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ from .registry import register_model @@ -40,106 +40,116 @@ def _cfg(url='', **kwargs): 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, 'first_conv': 'patch_embed.proj', 'classifier': 'head', **kwargs } default_cfgs = { - # FIXME weights coming + # patch models (weights from official Google JAX impl) 'vit_tiny_patch16_224': _cfg( - url='', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - ), + url='https://storage.googleapis.com/vit_models/augreg/' + 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 'vit_tiny_patch16_384': _cfg( - url='', - input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0 - ), - 'vit_small_patch16_224': _cfg( - url='', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - ), + url='https://storage.googleapis.com/vit_models/augreg/' + 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), 'vit_small_patch32_224': _cfg( - url='', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - ), - 'vit_small_patch16_384': _cfg( - url='', - input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0 - ), + 
url='https://storage.googleapis.com/vit_models/augreg/' + 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 'vit_small_patch32_384': _cfg( - url='', - input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0 - ), - - # patch models (weights ported from official Google JAX impl) - 'vit_base_patch16_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - ), + url='https://storage.googleapis.com/vit_models/augreg/' + 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_small_patch16_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), + 'vit_small_patch16_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), 'vit_base_patch32_224': _cfg( - url='', # no official model weights for this combo, only for in21k - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), - 'vit_base_patch16_384': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth', - input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 'vit_base_patch32_384': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth', - input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), - 'vit_large_patch16_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch16_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'), + 'vit_base_patch16_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), 'vit_large_patch32_224': _cfg( url='', # no official model weights for this combo, only for in21k - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), - 'vit_large_patch16_384': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth', - input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), + ), 'vit_large_patch32_384': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', - input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_large_patch16_224': _cfg( + 
url='https://storage.googleapis.com/vit_models/augreg/' + 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'), + 'vit_large_patch16_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), - # patch models, imagenet21k (weights ported from official Google JAX impl) - 'vit_base_patch16_224_in21k': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth', - num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + # patch models, imagenet21k (weights from official Google JAX impl) + 'vit_tiny_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_small_patch32_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_small_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843), 'vit_base_patch32_224_in21k': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth', - num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), - 'vit_large_patch16_224_in21k': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth', - num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_base_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', + num_classes=21843), 'vit_large_patch32_224_in21k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', - num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + num_classes=21843), + 'vit_large_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz', + num_classes=21843), 'vit_huge_patch14_224_in21k': _cfg( url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz', hf_hub='timm/vit_huge_patch14_224_in21k', - num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + num_classes=21843), # deit models (FB weights) 'deit_tiny_patch16_224': _cfg( - url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'), + url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 'deit_small_patch16_224': _cfg( - url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'), + url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 'deit_base_patch16_224': _cfg( - url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',), + url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth', + mean=IMAGENET_DEFAULT_MEAN, 
std=IMAGENET_DEFAULT_STD), 'deit_base_patch16_384': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth', - input_size=(3, 384, 384), crop_pct=1.0), + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0), 'deit_tiny_distilled_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth', - classifier=('head', 'head_dist')), + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), 'deit_small_distilled_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth', - classifier=('head', 'head_dist')), + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), 'deit_base_distilled_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', - classifier=('head', 'head_dist')), + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), 'deit_base_distilled_patch16_384': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth', - input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')), + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0, + classifier=('head', 'head_dist')), # ViT ImageNet-21K-P pretraining by MILL 'vit_base_patch16_224_miil_in21k': _cfg( @@ -530,12 +540,11 @@ def vit_tiny_patch16_224(pretrained=False, **kwargs): @register_model -def vit_small_patch16_224(pretrained=False, **kwargs): - """ ViT-Small (ViT-S/16) - NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper +def vit_tiny_patch16_384(pretrained=False, **kwargs): + """ ViT-Tiny (Vit-Ti/16) @ 384x384. """ - model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs) return model @@ -543,28 +552,37 @@ def vit_small_patch16_224(pretrained=False, **kwargs): def vit_small_patch32_224(pretrained=False, **kwargs): """ ViT-Small (ViT-S/32) """ - model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs) return model @register_model -def vit_small_patch16_384(pretrained=False, **kwargs): +def vit_small_patch32_384(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/32) at 384x384. 
+ """ + model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch16_224(pretrained=False, **kwargs): """ ViT-Small (ViT-S/16) NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper """ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs) + model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs) return model @register_model -def vit_base_patch16_224(pretrained=False, **kwargs): - """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). - ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. +def vit_small_patch16_384(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/16) + NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper """ - model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) - model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs) return model @@ -577,6 +595,26 @@ def vit_base_patch32_224(pretrained=False, **kwargs): return model +@register_model +def vit_base_patch32_384(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_base_patch16_384(pretrained=False, **kwargs): """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). @@ -588,31 +626,31 @@ def vit_base_patch16_384(pretrained=False, **kwargs): @register_model -def vit_base_patch32_384(pretrained=False, **kwargs): - """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). - ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. +def vit_large_patch32_224(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights. 
""" - model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) - model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs) return model @register_model -def vit_large_patch16_224(pretrained=False, **kwargs): +def vit_large_patch32_384(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). - ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. """ - model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) - model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs) return model @register_model -def vit_large_patch32_224(pretrained=False, **kwargs): - """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights. +def vit_large_patch16_224(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. """ - model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) - model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs) return model @@ -627,23 +665,32 @@ def vit_large_patch16_384(pretrained=False, **kwargs): @register_model -def vit_large_patch32_384(pretrained=False, **kwargs): - """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). - ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. +def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs): + """ ViT-Tiny (Vit-Ti/16). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. """ - model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) - model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs) return model @register_model -def vit_base_patch16_224_in21k(pretrained=False, **kwargs): - """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). +def vit_small_patch32_224_in21k(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/16) ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 
""" - model_kwargs = dict( - patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs) - model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch16_224_in21k(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/16) + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs) return model @@ -659,13 +706,13 @@ def vit_base_patch32_224_in21k(pretrained=False, **kwargs): @register_model -def vit_large_patch16_224_in21k(pretrained=False, **kwargs): - """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). +def vit_base_patch16_224_in21k(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. """ model_kwargs = dict( - patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs) - model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs) return model @@ -680,6 +727,17 @@ def vit_large_patch32_224_in21k(pretrained=False, **kwargs): return model +@register_model +def vit_large_patch16_224_in21k(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs) + model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_huge_patch14_224_in21k(pretrained=False, **kwargs): """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). 
diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index 1bfe6685..a53419a0 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -35,34 +35,51 @@ def _cfg(url='', **kwargs): default_cfgs = { - # hybrid in-1k models (weights ported from official JAX impl where they exist) - 'vit_tiny_r_s16_p8_224': _cfg(first_conv='patch_embed.backbone.conv'), + # hybrid in-1k models (weights from official JAX impl where they exist) + 'vit_tiny_r_s16_p8_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', + first_conv='patch_embed.backbone.conv'), 'vit_tiny_r_s16_p8_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0), - 'vit_small_r_s16_p8_224': _cfg(first_conv='patch_embed.backbone.conv'), - 'vit_small_r20_s16_p2_224': _cfg(), - 'vit_small_r20_s16_224': _cfg(), - 'vit_small_r26_s32_224': _cfg(), + 'vit_small_r26_s32_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz', + ), 'vit_small_r26_s32_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', input_size=(3, 384, 384), crop_pct=1.0), - 'vit_base_r20_s16_224': _cfg(), 'vit_base_r26_s32_224': _cfg(), 'vit_base_r50_s16_224': _cfg(), 'vit_base_r50_s16_384': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth', input_size=(3, 384, 384), crop_pct=1.0), - 'vit_large_r50_s32_224': _cfg(), - 'vit_large_r50_s32_384': _cfg(), - - # hybrid in-21k models (weights ported from official Google JAX impl where they exist) + 'vit_large_r50_s32_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz' + ), + 'vit_large_r50_s32_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0 + ), + + # hybrid in-21k models (weights from official Google JAX impl where they exist) + 'vit_tiny_r_s16_p8_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i1k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', + num_classes=21843, crop_pct=0.9, first_conv='patch_embed.backbone.conv'), 'vit_small_r26_s32_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0.npz', num_classes=21843, crop_pct=0.9), - 'vit_small_r26_s32_384_in21k': _cfg( - num_classes=21843, input_size=(3, 384, 384), crop_pct=1.0), 'vit_base_r50_s16_224_in21k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth', num_classes=21843, crop_pct=0.9), - 'vit_large_r50_s32_224_in21k': _cfg(num_classes=21843, crop_pct=0.9), + 'vit_large_r50_s32_224_in21k': _cfg( + 
url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz', + num_classes=21843, crop_pct=0.9), # hybrid models (using timm resnet backbones) 'vit_small_resnet26d_224': _cfg( @@ -163,51 +180,6 @@ def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs): return model -@register_model -def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs): - """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 384 x 384. - """ - backbone = _resnetv2(layers=(), **kwargs) - model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs) - model = _create_vision_transformer_hybrid( - 'vit_tiny_r_s16_p8_384', backbone=backbone, pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_small_r_s16_p8_224(pretrained=False, **kwargs): - """ R+ViT-S/S16 w/ 8x8 patch hybrid @ 224 x 224. - """ - backbone = _resnetv2(layers=(), **kwargs) - model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer_hybrid( - 'vit_small_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **model_kwargs) - - return model - - -@register_model -def vit_small_r20_s16_p2_224(pretrained=False, **kwargs): - """ R52+ViT-S/S16 w/ 2x2 patch hybrid @ 224 x 224. - """ - backbone = _resnetv2((2, 4), **kwargs) - model_kwargs = dict(patch_size=2, embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer_hybrid( - 'vit_small_r20_s16_p2_224', backbone=backbone, pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_small_r20_s16_224(pretrained=False, **kwargs): - """ R20+ViT-S/S16 hybrid. - """ - backbone = _resnetv2((2, 2, 2), **kwargs) - model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer_hybrid( - 'vit_small_r20_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs) - return model - - @register_model def vit_small_r26_s32_224(pretrained=False, **kwargs): """ R26+ViT-S/S32 hybrid. @@ -230,17 +202,6 @@ def vit_small_r26_s32_384(pretrained=False, **kwargs): return model -@register_model -def vit_base_r20_s16_224(pretrained=False, **kwargs): - """ R20+ViT-B/S16 hybrid. - """ - backbone = _resnetv2((2, 2, 2), **kwargs) - model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) - model = _create_vision_transformer_hybrid( - 'vit_base_r20_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs) - return model - - @register_model def vit_base_r26_s32_224(pretrained=False, **kwargs): """ R26+ViT-B/S32 hybrid. @@ -298,24 +259,24 @@ def vit_large_r50_s32_384(pretrained=False, **kwargs): @register_model -def vit_small_r26_s32_224_in21k(pretrained=False, **kwargs): - """ R26+ViT-S/S32 hybrid. +def vit_tiny_r_s16_p8_224_in21k(pretrained=False, **kwargs): + """ R+ViT-Ti/S16 w/ 8x8 patch hybrid. ImageNet-21k. """ - backbone = _resnetv2((2, 2, 2, 2), **kwargs) - model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs) + backbone = _resnetv2(layers=(), **kwargs) + model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs) model = _create_vision_transformer_hybrid( - 'vit_small_r26_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) + 'vit_tiny_r_s16_p8_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @register_model -def vit_small_r26_s32_384_in21k(pretrained=False, **kwargs): - """ R26+ViT-S/S32 hybrid. 
+def vit_small_r26_s32_224_in21k(pretrained=False, **kwargs): + """ R26+ViT-S/S32 hybrid. ImageNet-21k. """ backbone = _resnetv2((2, 2, 2, 2), **kwargs) model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs) model = _create_vision_transformer_hybrid( - 'vit_small_r26_s32_384_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) + 'vit_small_r26_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @@ -331,6 +292,17 @@ def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs): return model +@register_model +def vit_large_r50_s32_224_in21k(pretrained=False, **kwargs): + """ R50+ViT-L/S32 hybrid. ImageNet-21k. + """ + backbone = _resnetv2((3, 4, 6, 3), **kwargs) + model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_large_r50_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_small_resnet26d_224(pretrained=False, **kwargs): """ Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights. From 4c09a2f169587bb2b2ca35fb23e432a66038d8d8 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 18 Jun 2021 16:17:34 -0700 Subject: [PATCH 12/16] Bump version 0.4.12 --- timm/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/version.py b/timm/version.py index d4f33464..94c48197 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.4.11' +__version__ = '0.4.12' From 8f4a0222edae291c9fbb3636f23fe4299b7d523f Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 18 Jun 2021 16:49:28 -0700 Subject: [PATCH 13/16] Add GMixer-24 MLP model weights, trained w/ TPU + PyTorch XLA --- timm/models/mlp_mixer.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py index db3a1be5..7a87eb36 100644 --- a/timm/models/mlp_mixer.py +++ b/timm/models/mlp_mixer.py @@ -93,7 +93,9 @@ default_cfgs = dict( ), gmixer_12_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), - gmixer_24_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + gmixer_24_224=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gmixer_24_224_raa-7daf7ae6.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), resmlp_12_224=_cfg( url='https://dl.fbaipublicfiles.com/deit/resmlp_12_no_dist.pth', @@ -457,11 +459,11 @@ def mixer_b16_224_miil_in21k(pretrained=False, **kwargs): @register_model def gmixer_12_224(pretrained=False, **kwargs): - """ Glu-Mixer-12 224x224 (short & fat) + """ Glu-Mixer-12 224x224 Experiment by Ross Wightman, adding (Si)GLU to MLP-Mixer """ model_args = dict( - patch_size=16, num_blocks=12, embed_dim=512, mlp_ratio=(1.0, 6.0), + patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=(1.0, 4.0), mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs) model = _create_mixer('gmixer_12_224', pretrained=pretrained, **model_args) return model @@ -469,11 +471,11 @@ def gmixer_12_224(pretrained=False, **kwargs): @register_model def gmixer_24_224(pretrained=False, **kwargs): - """ Glu-Mixer-24 224x224 (tall & slim) + """ Glu-Mixer-24 224x224 Experiment by Ross Wightman, adding (Si)GLU to MLP-Mixer """ model_args = dict( - patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=(1.0, 6.0), + patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=(1.0, 4.0), mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs) model = _create_mixer('gmixer_24_224', 
pretrained=pretrained, **model_args) return model From 26f04a8e3ef7c581f4766cd34e71d94105c32064 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 19 Jun 2021 16:39:36 -0700 Subject: [PATCH 14/16] Fix a weight link --- timm/models/vision_transformer_hybrid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index a53419a0..30330d39 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -69,7 +69,7 @@ default_cfgs = { # hybrid in-21k models (weights from official Google JAX impl where they exist) 'vit_tiny_r_s16_p8_224_in21k': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i1k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', + url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', num_classes=21843, crop_pct=0.9, first_conv='patch_embed.backbone.conv'), 'vit_small_r26_s32_224_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0.npz', From 381b2797858248619fe8007fa1c5f5a5d4ab3919 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 19 Jun 2021 22:28:44 -0700 Subject: [PATCH 15/16] Add hybrid model fwds back --- tests/test_models.py | 2 +- timm/models/vision_transformer_hybrid.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index 52a8023a..0a770784 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -173,7 +173,7 @@ def test_model_default_cfgs_non_std(model_name, batch_size): state_dict = model.state_dict() cfg = model.default_cfg - input_size = _get_input_size(model_name=model_name, target=TARGET_FWD_SIZE) + input_size = _get_input_size(model=model) if max(input_size) > 320: # FIXME const pytest.skip("Fixed input size model > limit.") diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index 30330d39..5d725c58 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -236,6 +236,12 @@ def vit_base_r50_s16_384(pretrained=False, **kwargs): return model +@register_model +def vit_base_resnet50_384(pretrained=False, **kwargs): + # DEPRECATED this is forwarding to model def above for backwards compatibility + return vit_base_r50_s16_384(pretrained=pretrained, **kwargs) + + @register_model def vit_large_r50_s32_224(pretrained=False, **kwargs): """ R50+ViT-L/S32 hybrid. @@ -292,6 +298,12 @@ def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs): return model +@register_model +def vit_base_resnet50_224_in21k(pretrained=False, **kwargs): + # DEPRECATED this is forwarding to model def above for backwards compatibility + return vit_base_r50_s16_224_in21k(pretrained=pretrained, **kwargs) + + @register_model def vit_large_r50_s32_224_in21k(pretrained=False, **kwargs): """ R50+ViT-L/S32 hybrid. ImageNet-21k. 
From 9c9755a80869f7b42eb63c5bf9477aae3056615e Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 20 Jun 2021 17:46:06 -0700 Subject: [PATCH 16/16] AugReg release --- README.md | 19 +++++++++++++++++++ timm/models/vision_transformer.py | 11 ++++++++--- timm/models/vision_transformer_hybrid.py | 12 ++++++++---- 3 files changed, 35 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 704bc32c..76261ccc 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,25 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor ## What's New +### June 20, 2021 +* Release Vision Transformer 'AugReg' weights from [How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers](https://arxiv.org/abs/2106.10270) + * .npz weight loading support added, can load any of the 50K+ weights from the [AugReg series](https://console.cloud.google.com/storage/browser/vit_models/augreg) + * See [example notebook](https://colab.research.google.com/github/google-research/vision_transformer/blob/master/vit_jax_augreg.ipynb) from official impl for navigating the augreg weights + * Replaced all default weights w/ best AugReg variant (if possible). All AugReg 21k classifiers work. + * Highlights: `vit_large_patch16_384` (87.1 top-1), `vit_large_r50_s32_384` (86.2 top-1), `vit_base_patch16_384` (86.0 top-1) + * `vit_deit_*` renamed to just `deit_*` + * Remove my old small model, replace with DeiT compatible small w/ AugReg weights +* Add 1st training of my `gmixer_24_224` MLP /w GLU, 78.1 top-1 w/ 25M params. +* Add weights from official ResMLP release (https://github.com/facebookresearch/deit) +* Add `eca_nfnet_l2` weights from my 'lightweight' series. 84.7 top-1 at 384x384. +* Add distilled BiT 50x1 student and 152x2 Teacher weights from [Knowledge distillation: A good teacher is patient and consistent](https://arxiv.org/abs/2106.05237) +* NFNets and ResNetV2-BiT models work w/ Pytorch XLA now + * weight standardization uses F.batch_norm instead of std_mean (std_mean wasn't lowered) + * eps values adjusted, will be slight differences but should be quite close +* Improve test coverage and classifier interface of non-conv (vision transformer and mlp) models +* Cleanup a few classifier / flatten details for models w/ conv classifiers or early global pool +* Please report any regressions, this PR touched quite a few models. + ### June 8, 2021 * Add first ResMLP weights, trained in PyTorch XLA on TPU-VM w/ my XLA branch. 24 block variant, 79.2 top-1. * Add ResNet51-Q model w/ pretrained weights at 82.36 top-1. diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index b8fc6fa5..89fba7de 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -1,7 +1,12 @@ """ Vision Transformer (ViT) in PyTorch -A PyTorch implement of Vision Transformers as described in -'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929 +A PyTorch implement of Vision Transformers as described in: + +'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' + - https://arxiv.org/abs/2010.11929 + +`How to train your ViT? 
Data, Augmentation, and Regularization in Vision Transformers` - https://arxiv.org/abs/2106.10270 The official jax code is released and available at https://github.com/google-research/vision_transformer @@ -15,7 +20,7 @@ for some einops/einsum fun * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT * Bert reference code checks against Huggingface Transformers and Tensorflow Bert -Hacked together by / Copyright 2020 Ross Wightman +Hacked together by / Copyright 2021 Ross Wightman """ import math import logging diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index 5d725c58..d5f0a537 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -1,13 +1,17 @@ """ Hybrid Vision Transformer (ViT) in PyTorch -A PyTorch implement of the Hybrid Vision Transformers as described in +A PyTorch implement of the Hybrid Vision Transformers as described in: + 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929 -NOTE This relies on code in vision_transformer.py. The hybrid model definitions were moved here to -keep file sizes sane. +`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` + - https://arxiv.org/abs/2106.10270 + +NOTE These hybrid model definitions depend on code in vision_transformer.py. +They were moved here to keep file sizes sane. -Hacked together by / Copyright 2020 Ross Wightman +Hacked together by / Copyright 2021 Ross Wightman """ from copy import deepcopy from functools import partial
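
For reference, a minimal usage sketch of the AugReg weights released in the patches above. It is not part of the series itself; it assumes timm 0.4.12 as bumped earlier in the series, the model names registered in these diffs, and network access for the pretrained downloads:

    import torch
    import timm

    # 'vit_large_patch16_384' is one of the new AugReg defaults (87.1 top-1 per the README
    # entry above); pretrained=True pulls the .npz checkpoint listed in its default_cfgs
    # entry and loads it through the npz weight loading support added in this series.
    model = timm.create_model('vit_large_patch16_384', pretrained=True)
    model.eval()

    # default_cfg carries the preprocessing the weights expect.
    cfg = model.default_cfg
    print(cfg['input_size'], cfg['mean'], cfg['std'], cfg['crop_pct'])

    # Dummy forward pass at the configured resolution.
    with torch.no_grad():
        out = model(torch.randn(1, *cfg['input_size']))
    print(out.shape)  # torch.Size([1, 1000])

    # The former vit_deit_* names are registered as deit_* after this series.
    deit = timm.create_model('deit_base_patch16_224', pretrained=True)

Which checkpoint is downloaded is driven entirely by the url entries in the default_cfgs hunks above, so the same pattern applies to the renamed deit_* models and the new *_in21k and hybrid variants.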