Move DeiT to its own file; vit is getting crowded. Working towards fixing #1029: make the pooling interface for transformers and MLP models closer to the convnets. Still working through some details...

pull/1014/head
Ross Wightman 3 years ago
parent 95cfc9b3e8
commit 5f81d4de23
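
The theme of the refactor below: forward_features() returns the unpooled token sequence and forward() applies the global pooling and classifier head, mirroring the features -> pool -> head split the convnet models already use. A minimal sketch of the intended pattern (a hypothetical toy module, not the actual timm code):

import torch
import torch.nn as nn

class TinyTokenModel(nn.Module):
    """Hypothetical toy model, only to illustrate the features -> pool -> head split."""
    def __init__(self, embed_dim=64, num_classes=10, global_pool='avg', patch_size=16):
        super().__init__()
        self.global_pool = global_pool
        self.num_tokens = 1  # one class token prepended to the patch tokens
        self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward_features(self, x):
        x = self.patch_embed(x).flatten(2).transpose(1, 2)  # B x N x C patch tokens
        x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)
        # ... transformer blocks would run here ...
        return self.norm(x)  # unpooled B x (1 + N) x C

    def forward(self, x):
        x = self.forward_features(x)
        if self.global_pool == 'avg':
            x = x[:, self.num_tokens:].mean(dim=1)  # average the patch tokens
        else:
            x = x[:, 0]  # take the class token
        return self.head(x)

logits = TinyTokenModel()(torch.randn(2, 3, 224, 224))  # -> torch.Size([2, 10])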

@@ -205,22 +205,23 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
     outputs = model.forward_features(input_tensor)
     if isinstance(outputs, (tuple, list)):
         outputs = outputs[0]
-    assert outputs.shape[1] == model.num_features
+    feat_dim = -1 if outputs.ndim == 3 else 1
+    assert outputs.shape[feat_dim] == model.num_features

     # test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
     model.reset_classifier(0)
     outputs = model.forward(input_tensor)
     if isinstance(outputs, (tuple, list)):
         outputs = outputs[0]
-    assert len(outputs.shape) == 2
-    assert outputs.shape[1] == model.num_features
+    feat_dim = -1 if outputs.ndim == 3 else 1
+    assert outputs.shape[feat_dim] == model.num_features

     model = create_model(model_name, pretrained=False, num_classes=0).eval()
     outputs = model.forward(input_tensor)
     if isinstance(outputs, (tuple, list)):
         outputs = outputs[0]
-    assert len(outputs.shape) == 2
-    assert outputs.shape[1] == model.num_features
+    feat_dim = -1 if outputs.ndim == 3 else 1
+    assert outputs.shape[feat_dim] == model.num_features

     # check classifier name matches default_cfg
     if cfg.get('num_classes', None):
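
A hedged illustration of what the updated test asserts (assumes timm at this commit; both model names are existing registry entries): unpooled features come back as B x N x C for transformers and B x C x H x W for convnets, so the num_features dimension differs.

import timm
import torch

x = torch.randn(1, 3, 224, 224)
for name in ('vit_base_patch16_224', 'resnet50'):
    model = timm.create_model(name, pretrained=False).eval()
    feats = model.forward_features(x)          # NLC for ViT, NCHW for ResNet
    feat_dim = -1 if feats.ndim == 3 else 1
    assert feats.shape[feat_dim] == model.num_features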

@@ -8,6 +8,7 @@ from .convmixer import *
 from .convnext import *
 from .crossvit import *
 from .cspnet import *
+from .deit import *
 from .densenet import *
 from .dla import *
 from .dpn import *

@@ -232,13 +232,15 @@ class Beit(nn.Module):
     """ Vision Transformer with support for patch or hybrid CNN input stage
     """

-    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
-                 num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0.,
-                 drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), init_values=None,
-                 use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
-                 use_mean_pooling=True, init_scale=0.001):
+    def __init__(
+            self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='avg',
+            embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.,
+            attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),
+            init_values=None, use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
+            head_init_scale=0.001):
         super().__init__()
         self.num_classes = num_classes
+        self.global_pool = global_pool
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

         self.patch_embed = PatchEmbed(

@@ -247,10 +249,7 @@ class Beit(nn.Module):
         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
         # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
-        if use_abs_pos_emb:
-            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
-        else:
-            self.pos_embed = None
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) if use_abs_pos_emb else None
         self.pos_drop = nn.Dropout(p=drop_rate)

         if use_shared_rel_pos_bias:

@@ -266,8 +265,9 @@ class Beit(nn.Module):
                 drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                 init_values=init_values, window_size=self.patch_embed.grid_size if use_rel_pos_bias else None)
             for i in range(depth)])
-        self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
-        self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
+        use_fc_norm = self.global_pool == 'avg'
+        self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
+        self.fc_norm = norm_layer(embed_dim) if use_fc_norm else None
         self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

         self.apply(self._init_weights)

@@ -278,8 +278,8 @@ class Beit(nn.Module):
         self.fix_init_weight()
         if isinstance(self.head, nn.Linear):
             trunc_normal_(self.head.weight, std=.02)
-            self.head.weight.data.mul_(init_scale)
-            self.head.bias.data.mul_(init_scale)
+            self.head.weight.data.mul_(head_init_scale)
+            self.head.bias.data.mul_(head_init_scale)

     def fix_init_weight(self):
         def rescale(param, layer_id):

@@ -327,14 +327,15 @@ class Beit(nn.Module):
             x = blk(x, rel_pos_bias=rel_pos_bias)
         x = self.norm(x)
-        if self.fc_norm is not None:
-            t = x[:, 1:, :]
-            return self.fc_norm(t.mean(1))
-        else:
-            return x[:, 0]
+        return x

     def forward(self, x):
         x = self.forward_features(x)
+        if self.fc_norm is not None:
+            x = x[:, 1:].mean(dim=1)
+            x = self.fc_norm(x)
+        else:
+            x = x[:, 0]
         x = self.head(x)
         return x

@@ -213,11 +213,11 @@ class Cait(nn.Module):
             act_layer=nn.GELU,
             attn_block=TalkingHeadAttn,
             mlp_block=Mlp,
-            init_scale=1e-4,
+            init_values=1e-4,
             attn_block_token_only=ClassAttn,
             mlp_block_token_only=Mlp,
             depth_token_only=2,
-            mlp_ratio_clstk=4.0
+            mlp_ratio_token_only=4.0
     ):
         super().__init__()

@@ -234,19 +234,19 @@ class Cait(nn.Module):
         self.pos_drop = nn.Dropout(p=drop_rate)

         dpr = [drop_path_rate for i in range(depth)]
-        self.blocks = nn.ModuleList([
+        self.blocks = nn.Sequential(*[
             block_layers(
                 dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                 drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
-                act_layer=act_layer, attn_block=attn_block, mlp_block=mlp_block, init_values=init_scale)
+                act_layer=act_layer, attn_block=attn_block, mlp_block=mlp_block, init_values=init_values)
             for i in range(depth)])

         self.blocks_token_only = nn.ModuleList([
             block_layers_token(
-                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio_clstk, qkv_bias=qkv_bias,
+                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio_token_only, qkv_bias=qkv_bias,
                 drop=0.0, attn_drop=0.0, drop_path=0.0, norm_layer=norm_layer,
                 act_layer=act_layer, attn_block=attn_block_token_only,
-                mlp_block=mlp_block_token_only, init_values=init_scale)
+                mlp_block=mlp_block_token_only, init_values=init_values)
             for i in range(depth_token_only)])

         self.norm = norm_layer(embed_dim)

@@ -281,25 +281,21 @@ class Cait(nn.Module):
     def forward_features(self, x):
         B = x.shape[0]
         x = self.patch_embed(x)
-        cls_tokens = self.cls_token.expand(B, -1, -1)
         x = x + self.pos_embed
         x = self.pos_drop(x)
+        x = self.blocks(x)

-        for i, blk in enumerate(self.blocks):
-            x = blk(x)
-
+        cls_tokens = self.cls_token.expand(B, -1, -1)
         for i, blk in enumerate(self.blocks_token_only):
             cls_tokens = blk(x, cls_tokens)
-
         x = torch.cat((cls_tokens, x), dim=1)
-
         x = self.norm(x)
-        return x[:, 0]
+        return x

     def forward(self, x):
         x = self.forward_features(x)
+        x = x[:, 0]
         x = self.head(x)
         return x

@@ -326,69 +322,69 @@ def _create_cait(variant, pretrained=False, **kwargs):

 @register_model
 def cait_xxs24_224(pretrained=False, **kwargs):
-    model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_scale=1e-5, **kwargs)
+    model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_values=1e-5, **kwargs)
     model = _create_cait('cait_xxs24_224', pretrained=pretrained, **model_args)
     return model

 @register_model
 def cait_xxs24_384(pretrained=False, **kwargs):
-    model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_scale=1e-5, **kwargs)
+    model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_values=1e-5, **kwargs)
     model = _create_cait('cait_xxs24_384', pretrained=pretrained, **model_args)
     return model

 @register_model
 def cait_xxs36_224(pretrained=False, **kwargs):
-    model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_scale=1e-5, **kwargs)
+    model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_values=1e-5, **kwargs)
     model = _create_cait('cait_xxs36_224', pretrained=pretrained, **model_args)
     return model

 @register_model
 def cait_xxs36_384(pretrained=False, **kwargs):
-    model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_scale=1e-5, **kwargs)
+    model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_values=1e-5, **kwargs)
     model = _create_cait('cait_xxs36_384', pretrained=pretrained, **model_args)
     return model

 @register_model
 def cait_xs24_384(pretrained=False, **kwargs):
-    model_args = dict(patch_size=16, embed_dim=288, depth=24, num_heads=6, init_scale=1e-5, **kwargs)
+    model_args = dict(patch_size=16, embed_dim=288, depth=24, num_heads=6, init_values=1e-5, **kwargs)
     model = _create_cait('cait_xs24_384', pretrained=pretrained, **model_args)
     return model

 @register_model
 def cait_s24_224(pretrained=False, **kwargs):
-    model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_scale=1e-5, **kwargs)
+    model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_values=1e-5, **kwargs)
     model = _create_cait('cait_s24_224', pretrained=pretrained, **model_args)
     return model

 @register_model
 def cait_s24_384(pretrained=False, **kwargs):
-    model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_scale=1e-5, **kwargs)
+    model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_values=1e-5, **kwargs)
     model = _create_cait('cait_s24_384', pretrained=pretrained, **model_args)
     return model

 @register_model
 def cait_s36_384(pretrained=False, **kwargs):
-    model_args = dict(patch_size=16, embed_dim=384, depth=36, num_heads=8, init_scale=1e-6, **kwargs)
+    model_args = dict(patch_size=16, embed_dim=384, depth=36, num_heads=8, init_values=1e-6, **kwargs)
     model = _create_cait('cait_s36_384', pretrained=pretrained, **model_args)
     return model

 @register_model
 def cait_m36_384(pretrained=False, **kwargs):
-    model_args = dict(patch_size=16, embed_dim=768, depth=36, num_heads=16, init_scale=1e-6, **kwargs)
+    model_args = dict(patch_size=16, embed_dim=768, depth=36, num_heads=16, init_values=1e-6, **kwargs)
     model = _create_cait('cait_m36_384', pretrained=pretrained, **model_args)
     return model

 @register_model
 def cait_m48_448(pretrained=False, **kwargs):
-    model_args = dict(patch_size=16, embed_dim=768, depth=48, num_heads=16, init_scale=1e-6, **kwargs)
+    model_args = dict(patch_size=16, embed_dim=768, depth=48, num_heads=16, init_values=1e-6, **kwargs)
     model = _create_cait('cait_m48_448', pretrained=pretrained, **model_args)
     return model

@@ -447,6 +447,7 @@ class CoaT(nn.Module):
             self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
         else:
             # CoaT-Lite series: Use feature of last scale for classification.
+            self.aggregate = None
             self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

         # Initialize weights.

@@ -542,8 +543,7 @@ class CoaT(nn.Module):
             else:
                 # Return features for classification.
                 x4 = self.norm4(x4)
-                x4_cls = x4[:, 0]
-                return x4_cls
+                return x4

         # Parallel blocks.
         for blk in self.parallel_blocks:

@@ -574,20 +574,20 @@ class CoaT(nn.Module):
             x2 = self.norm2(x2)
             x3 = self.norm3(x3)
             x4 = self.norm4(x4)
-            x2_cls = x2[:, :1]  # [B, 1, C]
-            x3_cls = x3[:, :1]
-            x4_cls = x4[:, :1]
-            merged_cls = torch.cat((x2_cls, x3_cls, x4_cls), dim=1)  # [B, 3, C]
-            merged_cls = self.aggregate(merged_cls).squeeze(dim=1)  # Shape: [B, C]
-            return merged_cls
-
-    def forward(self, x):
-        if self.return_interm_layers:
+            return [x2, x3, x4]
+
+    def forward(self, x) -> torch.Tensor:
+        if not torch.jit.is_scripting() and self.return_interm_layers:
             # Return intermediate features (for down-stream tasks).
             return self.forward_features(x)
         else:
             # Return features for classification.
-            x = self.forward_features(x)
+            x_feat = self.forward_features(x)
+            if isinstance(x_feat, (tuple, list)):
+                x = torch.cat([xl[:, :1] for xl in x_feat], dim=1)  # [B, 3, C]
+                x = self.aggregate(x).squeeze(dim=1)  # Shape: [B, C]
+            else:
+                x = x_feat[:, 0]
             x = self.head(x)
             return x

@@ -308,10 +308,11 @@ class ConViT(nn.Module):
             x = blk(x)
         x = self.norm(x)
-        return x[:, 0]
+        return x

     def forward(self, x):
         x = self.forward_features(x)
+        x = x[:, 0]
         x = self.head(x)
         return x

@@ -69,13 +69,12 @@ class ConvMixer(nn.Module):
     def forward_features(self, x):
         x = self.stem(x)
         x = self.blocks(x)
-        x = self.pooling(x)
         return x

     def forward(self, x):
         x = self.forward_features(x)
+        x = self.pooling(x)
         x = self.head(x)
         return x

@@ -319,7 +319,6 @@ def checkpoint_filter_fn(state_dict, model):
 def _create_convnext(variant, pretrained=False, **kwargs):
     model = build_model_with_cfg(
         ConvNeXt, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
         **kwargs)

@@ -368,7 +368,7 @@ class CrossViT(nn.Module):
             [nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in
              range(self.num_branches)])

-    def forward_features(self, x):
+    def forward_features(self, x) -> List[torch.Tensor]:
         B = x.shape[0]
         xs = []
         for i, patch_embed in enumerate(self.patch_embed):

@@ -389,11 +389,11 @@ class CrossViT(nn.Module):
         # NOTE: was before branch token section, move to here to assure all branch token are before layer norm
         xs = [norm(xs[i]) for i, norm in enumerate(self.norm)]
-        return [xo[:, 0] for xo in xs]
+        return xs

     def forward(self, x):
         xs = self.forward_features(x)
-        ce_logits = [head(xs[i]) for i, head in enumerate(self.head)]
+        ce_logits = [head(xs[i][:, 0]) for i, head in enumerate(self.head)]
         if not isinstance(self.head[0], nn.Identity):
             ce_logits = torch.mean(torch.stack(ce_logits, dim=0), dim=0)
         return ce_logits

@@ -0,0 +1,201 @@ deit.py (new file)
""" DeiT - Data-efficient Image Transformers
DeiT model defs and weights from https://github.com/facebookresearch/deit, original copyright below
paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
Modifications copyright 2021, Ross Wightman
"""
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import torch
from torch import nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.vision_transformer import VisionTransformer, trunc_normal_, checkpoint_filter_fn
from .helpers import build_model_with_cfg
from .registry import register_model
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
default_cfgs = {
# deit models (FB weights)
'deit_tiny_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
'deit_small_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
'deit_base_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth'),
'deit_base_patch16_384': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'deit_tiny_distilled_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
classifier=('head', 'head_dist')),
'deit_small_distilled_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
classifier=('head', 'head_dist')),
'deit_base_distilled_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
classifier=('head', 'head_dist')),
'deit_base_distilled_patch16_384': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
input_size=(3, 384, 384), crop_pct=1.0,
classifier=('head', 'head_dist')),
}
class VisionTransformerDistilled(VisionTransformer):
""" Vision Transformer w/ Distillation Token and Head
Distillation token & head support for `DeiT: Data-efficient Image Transformers`
- https://arxiv.org/abs/2012.12877
"""
def __init__(self, *args, **kwargs):
weight_init = kwargs.pop('weight_init', '')
super().__init__(*args, **kwargs, weight_init='skip')
self.num_tokens = 2
self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + self.num_tokens, self.embed_dim))
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
self.init_weights(weight_init)
def init_weights(self, mode=''):
trunc_normal_(self.dist_token, std=.02)
super().init_weights(mode=mode)
def get_classifier(self):
return self.head, self.head_dist
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x) -> torch.Tensor:
x = self.patch_embed(x)
x = torch.cat((
self.cls_token.expand(x.shape[0], -1, -1),
self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
x = self.pos_drop(x + self.pos_embed)
x = self.blocks(x)
x = self.norm(x)
return x
def forward(self, x):
x = self.forward_features(x)
x_dist = self.head_dist(x[:, 1])
x = self.head(x[:, 0])
if self.training and not torch.jit.is_scripting():
return x, x_dist
else:
# during inference, return the average of both classifier predictions
return (x + x_dist) / 2
def _create_deit(variant, pretrained=False, distilled=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model_cls = VisionTransformerDistilled if distilled else VisionTransformer
model = build_model_with_cfg(
model_cls, variant, pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs)
return model
@register_model
def deit_tiny_patch16_224(pretrained=False, **kwargs):
""" DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_deit('deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit_small_patch16_224(pretrained=False, **kwargs):
""" DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_deit('deit_small_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit_base_patch16_224(pretrained=False, **kwargs):
""" DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_deit('deit_base_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit_base_patch16_384(pretrained=False, **kwargs):
""" DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_deit('deit_base_patch16_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
""" DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_deit(
'deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
return model
@register_model
def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
""" DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_deit(
'deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
return model
@register_model
def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
""" DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_deit(
'deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
return model
@register_model
def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
""" DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_deit(
'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
return model
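
A rough usage sketch for the distilled variants defined above (assumes these registrations are importable via timm.create_model): in training mode the model returns a (class_logits, distillation_logits) tuple, at inference it returns their average.

import timm
import torch

model = timm.create_model('deit_tiny_distilled_patch16_224', pretrained=False)
x = torch.randn(2, 3, 224, 224)

model.train()
logits_cls, logits_dist = model(x)   # two heads while training

model.eval()
with torch.no_grad():
    logits = model(x)                # averaged head predictions at inference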

@@ -290,10 +290,10 @@ class Attention(nn.Module):
             qkv = self.qkv(x)
             q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3)
             q = q.permute(0, 2, 1, 3)
-            k = k.permute(0, 2, 1, 3)
+            k = k.permute(0, 2, 3, 1)
             v = v.permute(0, 2, 1, 3)
-            attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device)
+            attn = q @ k * self.scale + self.get_attention_biases(x.device)
             attn = attn.softmax(dim=-1)
             x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)

@@ -383,11 +383,11 @@ class AttentionSubsample(nn.Module):
         else:
             B, N, C = x.shape
             k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.d], dim=3)
-            k = k.permute(0, 2, 1, 3)  # BHNC
+            k = k.permute(0, 2, 3, 1)  # BHCN
             v = v.permute(0, 2, 1, 3)  # BHNC
             q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3)

-            attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device)
+            attn = q @ k * self.scale + self.get_attention_biases(x.device)
             attn = attn.softmax(dim=-1)
             x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh)

@@ -519,11 +519,11 @@ class Levit(nn.Module):
         if not self.use_conv:
             x = x.flatten(2).transpose(1, 2)
         x = self.blocks(x)
-        x = x.mean((-2, -1)) if self.use_conv else x.mean(1)
         return x

     def forward(self, x):
         x = self.forward_features(x)
+        x = x.mean((-2, -1)) if self.use_conv else x.mean(1)
         if self.head_dist is not None:
             x, x_dist = self.head(x), self.head_dist(x)
             if self.training and not torch.jit.is_scripting():

@@ -294,11 +294,11 @@ class MlpMixer(nn.Module):
         x = self.stem(x)
         x = self.blocks(x)
         x = self.norm(x)
-        x = x.mean(dim=1)
         return x

     def forward(self, x):
         x = self.forward_features(x)
+        x = x.mean(dim=1)
         x = self.head(x)
         return x

@@ -200,7 +200,8 @@ class MobileNetV3Features(nn.Module):
     and object detection models.
     """

-    def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3,
+    def __init__(
+            self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3,
             stem_size=16, fix_stem=False, output_stride=32, pad_type='', round_chs_fn=round_channels,
             se_from_exp=True, act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.):
         super(MobileNetV3Features, self).__init__()

@@ -125,10 +125,8 @@ class ConvHeadPooling(nn.Module):
         self.fc = nn.Linear(in_feature, out_feature)

     def forward(self, x, cls_token) -> Tuple[torch.Tensor, torch.Tensor]:
-
         x = self.conv(x)
         cls_token = self.fc(cls_token)
-
         return x, cls_token

@@ -225,21 +223,18 @@ class PoolingVisionTransformer(nn.Module):
         cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
         x, cls_tokens = self.transformers((x, cls_tokens))
         cls_tokens = self.norm(cls_tokens)
-        if self.head_dist is not None:
-            return cls_tokens[:, 0], cls_tokens[:, 1]
-        else:
-            return cls_tokens[:, 0]
+        return cls_tokens

     def forward(self, x):
         x = self.forward_features(x)
         if self.head_dist is not None:
-            x, x_dist = self.head(x[0]), self.head_dist(x[1])  # x must be a tuple
+            x, x_dist = self.head(x[:, 0]), self.head_dist(x[:, 1])  # x must be a tuple
             if self.training and not torch.jit.is_scripting():
                 return x, x_dist
             else:
                 return (x + x_dist) / 2
         else:
-            return self.head(x)
+            return self.head(x[:, 0])

 def checkpoint_filter_fn(state_dict, model):

@@ -14,7 +14,7 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
 # --------------------------------------------------------
 import logging
 import math
-from copy import deepcopy
+from functools import partial
 from typing import Optional

 import torch

@@ -23,9 +23,8 @@ import torch.utils.checkpoint as checkpoint
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .fx_features import register_notrace_function
-from .helpers import build_model_with_cfg
-from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
-from .layers import _assert
+from .helpers import build_model_with_cfg, named_apply
+from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert
 from .registry import register_model
 from .vision_transformer import checkpoint_filter_fn, _init_vit_weights

@@ -444,15 +443,17 @@ class SwinTransformer(nn.Module):
         use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
     """

-    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
+    def __init__(
+            self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, global_pool='avg',
             embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
             window_size=7, mlp_ratio=4., qkv_bias=True,
             drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
             norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
             use_checkpoint=False, weight_init='', **kwargs):
         super().__init__()
+        assert global_pool in ('', 'avg')
         self.num_classes = num_classes
+        self.global_pool = global_pool
         self.num_layers = len(depths)
         self.embed_dim = embed_dim
         self.ape = ape

@@ -468,18 +469,11 @@ class SwinTransformer(nn.Module):
         self.patch_grid = self.patch_embed.grid_size

         # absolute position embedding
-        if self.ape:
-            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
-            trunc_normal_(self.absolute_pos_embed, std=.02)
-        else:
-            self.absolute_pos_embed = None
+        self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) if ape else None

         self.pos_drop = nn.Dropout(p=drop_rate)

-        # stochastic depth
-        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
-
         # build layers
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
         layers = []
         for i_layer in range(self.num_layers):
             layers += [BasicLayer(

@@ -500,16 +494,16 @@ class SwinTransformer(nn.Module):
         self.layers = nn.Sequential(*layers)

         self.norm = norm_layer(self.num_features)
-        self.avgpool = nn.AdaptiveAvgPool1d(1)
         self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

-        assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '')
-        head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0.
-        if weight_init.startswith('jax'):
-            for n, m in self.named_modules():
-                _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True)
-        else:
-            self.apply(_init_vit_weights)
+        self.init_weights(weight_init)
+
+    def init_weights(self, mode=''):
+        assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
+        if self.absolute_pos_embed is not None:
+            trunc_normal_(self.absolute_pos_embed, std=.02)
+        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
+        named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl='jax' in mode), self)

     @torch.jit.ignore
     def no_weight_decay(self):

@@ -522,8 +516,9 @@ class SwinTransformer(nn.Module):
     def get_classifier(self):
         return self.head

-    def reset_classifier(self, num_classes, global_pool=''):
+    def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
+        self.global_pool = global_pool
         self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

     def forward_features(self, x):

@@ -533,12 +528,12 @@ class SwinTransformer(nn.Module):
         x = self.pos_drop(x)
         x = self.layers(x)
         x = self.norm(x)  # B L C
-        x = self.avgpool(x.transpose(1, 2))  # B C 1
-        x = torch.flatten(x, 1)
         return x

     def forward(self, x):
         x = self.forward_features(x)
+        if self.global_pool == 'avg':
+            x = x.mean(dim=1)
         x = self.head(x)
         return x

@@ -226,10 +226,11 @@ class TNT(nn.Module):
             pixel_embed, patch_embed = blk(pixel_embed, patch_embed)

         patch_embed = self.norm(patch_embed)
-        return patch_embed[:, 0]
+        return patch_embed

     def forward(self, x):
         x = self.forward_features(x)
+        x = x[:, 0]
         x = self.head(x)
         return x

@@ -357,10 +357,11 @@ class Twins(nn.Module):
             if i < len(self.depths) - 1:
                 x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
         x = self.norm(x)
-        return x.mean(dim=1)  # GAP here
+        return x

     def forward(self, x):
         x = self.forward_features(x)
+        x = x.mean(dim=1)
         x = self.head(x)
         return x

@@ -10,9 +10,6 @@ A PyTorch implement of Vision Transformers as described in:

 The official jax code is released and available at https://github.com/google-research/vision_transformer

-DeiT model defs and weights from https://github.com/facebookresearch/deit,
-paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
-
 Acknowledgments:
 * The paper authors for releasing code and weights, thanks!
 * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out

@@ -26,7 +23,6 @@ import math
 import logging
 from functools import partial
 from collections import OrderedDict
-from copy import deepcopy

 import torch
 import torch.nn as nn

@@ -105,6 +101,7 @@ default_cfgs = {
             'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
         input_size=(3, 384, 384), crop_pct=1.0),

+    'vit_large_patch14_224': _cfg(url=''),
     'vit_huge_patch14_224': _cfg(url=''),
     'vit_giant_patch14_224': _cfg(url=''),
     'vit_gigantic_patch14_224': _cfg(url=''),

@@ -161,32 +158,6 @@ default_cfgs = {
         url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),

-    # deit models (FB weights)
-    'deit_tiny_patch16_224': _cfg(
-        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
-    'deit_small_patch16_224': _cfg(
-        url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
-    'deit_base_patch16_224': _cfg(
-        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
-    'deit_base_patch16_384': _cfg(
-        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0),
-    'deit_tiny_distilled_patch16_224': _cfg(
-        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
-    'deit_small_distilled_patch16_224': _cfg(
-        url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
-    'deit_base_distilled_patch16_224': _cfg(
-        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
-    'deit_base_distilled_patch16_384': _cfg(
-        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0,
-        classifier=('head', 'head_dist')),
-
     # ViT ImageNet-21K-P pretraining by MILL
     'vit_base_patch16_224_miil_in21k': _cfg(
@@ -253,15 +224,13 @@ class VisionTransformer(nn.Module):

     A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
         - https://arxiv.org/abs/2010.11929
-
-    Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
-        - https://arxiv.org/abs/2012.12877
     """

-    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
-                 num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False,
-                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
-                 act_layer=None, weight_init=''):
+    def __init__(
+            self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+            num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, global_pool='',
+            drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='',
+            embed_layer=PatchEmbed, norm_layer=None, act_layer=None):
         """
         Args:
             img_size (int, tuple): input image size

@@ -274,18 +243,19 @@ class VisionTransformer(nn.Module):
             mlp_ratio (int): ratio of mlp hidden dim to embedding dim
             qkv_bias (bool): enable bias for qkv if True
             representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
-            distilled (bool): model includes a distillation token and head as in DeiT models
+            weight_init: (str): weight init scheme
             drop_rate (float): dropout rate
             attn_drop_rate (float): attention dropout rate
             drop_path_rate (float): stochastic depth rate
             embed_layer (nn.Module): patch embedding layer
             norm_layer: (nn.Module): normalization layer
-            weight_init: (str): weight init scheme
+            act_layer: (nn.Module): MLP activation layer
         """
         super().__init__()
         self.num_classes = num_classes
+        self.global_pool = global_pool
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
-        self.num_tokens = 2 if distilled else 1
+        self.num_tokens = 1
         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
         act_layer = act_layer or nn.GELU

@@ -294,7 +264,6 @@ class VisionTransformer(nn.Module):
         num_patches = self.patch_embed.num_patches

         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
-        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
         self.pos_drop = nn.Dropout(p=drop_rate)

@@ -304,38 +273,41 @@ class VisionTransformer(nn.Module):
                 dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
                 attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
             for i in range(depth)])
-        self.norm = norm_layer(embed_dim)
+        use_fc_norm = self.global_pool == 'avg'
+        self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
+
+        # Representation layer. Used for original ViT models w/ in21k pretraining.
+        self.representation_size = representation_size
+        self.pre_logits = nn.Identity()
+        if representation_size:
+            self._reset_representation(representation_size)
+
+        # Classifier Head
+        self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
+        final_chs = self.representation_size if self.representation_size else self.embed_dim
+        self.head = nn.Linear(final_chs, num_classes) if num_classes > 0 else nn.Identity()
+
+        if weight_init != 'skip':
+            self.init_weights(weight_init)

-        # Representation layer
-        if representation_size and not distilled:
-            self.num_features = representation_size
+    def _reset_representation(self, representation_size):
+        self.representation_size = representation_size
+        if self.representation_size:
             self.pre_logits = nn.Sequential(OrderedDict([
-                ('fc', nn.Linear(embed_dim, representation_size)),
+                ('fc', nn.Linear(self.embed_dim, self.representation_size)),
                 ('act', nn.Tanh())
             ]))
         else:
             self.pre_logits = nn.Identity()

-        # Classifier head(s)
-        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
-        self.head_dist = None
-        if distilled:
-            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
-
-        self.init_weights(weight_init)
-
     def init_weights(self, mode=''):
         assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
         head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
         trunc_normal_(self.pos_embed, std=.02)
-        if self.dist_token is not None:
-            trunc_normal_(self.dist_token, std=.02)
-        if mode.startswith('jax'):
-            # leave cls token as zeros to match jax impl
-            named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self)
-        else:
+        if 'jax' not in mode:
+            # init cls token to truncated normal if not following jax impl, jax impl is zero
             trunc_normal_(self.cls_token, std=.02)
-            self.apply(_init_vit_weights)
+        named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl='jax' in mode), self)

     def _init_weights(self, m):
         # this fn left here for compat with downstream users

@@ -350,42 +322,32 @@ class VisionTransformer(nn.Module):
         return {'pos_embed', 'cls_token', 'dist_token'}

     def get_classifier(self):
-        if self.dist_token is None:
-            return self.head
-        else:
-            return self.head, self.head_dist
+        return self.head

-    def reset_classifier(self, num_classes, global_pool=''):
+    def reset_classifier(self, num_classes, global_pool='', representation_size=None):
         self.num_classes = num_classes
-        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
-        if self.num_tokens == 2:
-            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
+        self.global_pool = global_pool
+        if representation_size is not None:
+            self._reset_representation(representation_size)
+        final_chs = self.representation_size if self.representation_size else self.embed_dim
+        self.head = nn.Linear(final_chs, num_classes) if num_classes > 0 else nn.Identity()

     def forward_features(self, x):
         x = self.patch_embed(x)
-        cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
-        if self.dist_token is None:
-            x = torch.cat((cls_token, x), dim=1)
-        else:
-            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
+        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
         x = self.pos_drop(x + self.pos_embed)
         x = self.blocks(x)
         x = self.norm(x)
-        if self.dist_token is None:
-            return self.pre_logits(x[:, 0])
-        else:
-            return x[:, 0], x[:, 1]
+        return x

     def forward(self, x):
         x = self.forward_features(x)
-        if self.head_dist is not None:
-            x, x_dist = self.head(x[0]), self.head_dist(x[1])  # x must be a tuple
-            if self.training and not torch.jit.is_scripting():
-                # during inference, return the average of both classifier predictions
-                return x, x_dist
-            else:
-                return (x + x_dist) / 2
+        if self.global_pool == 'avg':
+            x = x[:, self.num_tokens:].mean(dim=1)
         else:
-            x = self.head(x)
+            x = x[:, 0]
+        x = self.fc_norm(x)
+        x = self.pre_logits(x)
+        x = self.head(x)
         return x
@@ -708,7 +670,7 @@ def vit_large_patch32_384(pretrained=False, **kwargs):

 @register_model
 def vit_large_patch16_224(pretrained=False, **kwargs):
-    """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
+    """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
     """
     model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)

@@ -726,6 +688,15 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
     return model

+@register_model
+def vit_large_patch14_224(pretrained=False, **kwargs):
+    """ ViT-Large model (ViT-L/14)
+    """
+    model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, **kwargs)
+    model = _create_vision_transformer('vit_large_patch14_224', pretrained=pretrained, **model_kwargs)
+    return model
+
 @register_model
 def vit_huge_patch14_224(pretrained=False, **kwargs):
     """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
@@ -914,90 +885,6 @@ def vit_base_patch8_224_dino(pretrained=False, **kwargs):
     return model
@register_model
def deit_tiny_patch16_224(pretrained=False, **kwargs):
""" DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer('deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit_small_patch16_224(pretrained=False, **kwargs):
""" DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('deit_small_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit_base_patch16_224(pretrained=False, **kwargs):
""" DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('deit_base_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit_base_patch16_384(pretrained=False, **kwargs):
""" DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('deit_base_patch16_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
""" DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer(
'deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
return model
@register_model
def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
""" DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer(
'deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
return model
@register_model
def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
""" DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer(
'deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
return model
@register_model
def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
""" DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer(
'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
return model
 @register_model
 def vit_base_patch16_224_miil_in21k(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).

@@ -426,17 +426,17 @@ class XCiT(nn.Module):
         for blk in self.blocks:
             x = blk(x, Hp, Wp)

-        cls_tokens = self.cls_token.expand(B, -1, -1)
-        x = torch.cat((cls_tokens, x), dim=1)
+        x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)

         for blk in self.cls_attn_blocks:
             x = blk(x)

-        x = self.norm(x)[:, 0]
+        x = self.norm(x)
         return x

     def forward(self, x):
         x = self.forward_features(x)
+        x = x[:, 0]
         x = self.head(x)
         return x
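
Hedged sketch of the behaviour the ViT refactor enables, mirroring the updated test at the top of this commit (assumes timm with these changes applied): forward_features() keeps the full token sequence, while forward() pools it, so a model built with num_classes=0 yields num_features-sized embeddings.

import timm
import torch

model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=0).eval()
x = torch.randn(1, 3, 224, 224)

tokens = model.forward_features(x)   # B x N x C token sequence, unpooled
pooled = model(x)                    # B x C: class token (or avg) pooled, head is Identity
assert tokens.ndim == 3 and pooled.shape[-1] == model.num_features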
