Move DeiT to own file, vit getting crowded. Working towards fixing #1029, make pooling interface for transformers and mlp closer to convnets. Still working through some details...
parent
95cfc9b3e8
commit
5f81d4de23
@ -0,0 +1,201 @@
|
||||
""" DeiT - Data-efficient Image Transformers
|
||||
|
||||
DeiT model defs and weights from https://github.com/facebookresearch/deit, original copyright below
|
||||
paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
|
||||
|
||||
Modifications copyright 2021, Ross Wightman
|
||||
"""
|
||||
# Copyright (c) 2015-present, Facebook, Inc.
|
||||
# All rights reserved.
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.models.vision_transformer import VisionTransformer, trunc_normal_, checkpoint_filter_fn
|
||||
|
||||
from .helpers import build_model_with_cfg
|
||||
from .registry import register_model
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
||||
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'patch_embed.proj', 'classifier': 'head',
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
# deit models (FB weights)
|
||||
'deit_tiny_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
|
||||
'deit_small_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
|
||||
'deit_base_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth'),
|
||||
'deit_base_patch16_384': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
|
||||
'deit_tiny_distilled_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
|
||||
classifier=('head', 'head_dist')),
|
||||
'deit_small_distilled_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
|
||||
classifier=('head', 'head_dist')),
|
||||
'deit_base_distilled_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
|
||||
classifier=('head', 'head_dist')),
|
||||
'deit_base_distilled_patch16_384': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0,
|
||||
classifier=('head', 'head_dist')),
|
||||
}
|
||||
|
||||
|
||||
class VisionTransformerDistilled(VisionTransformer):
|
||||
""" Vision Transformer w/ Distillation Token and Head
|
||||
|
||||
Distillation token & head support for `DeiT: Data-efficient Image Transformers`
|
||||
- https://arxiv.org/abs/2012.12877
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
weight_init = kwargs.pop('weight_init', '')
|
||||
super().__init__(*args, **kwargs, weight_init='skip')
|
||||
self.num_tokens = 2
|
||||
self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + self.num_tokens, self.embed_dim))
|
||||
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
|
||||
|
||||
self.init_weights(weight_init)
|
||||
|
||||
def init_weights(self, mode=''):
|
||||
trunc_normal_(self.dist_token, std=.02)
|
||||
super().init_weights(mode=mode)
|
||||
|
||||
def get_classifier(self):
|
||||
return self.head, self.head_dist
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=''):
|
||||
self.num_classes = num_classes
|
||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def forward_features(self, x) -> torch.Tensor:
|
||||
x = self.patch_embed(x)
|
||||
x = torch.cat((
|
||||
self.cls_token.expand(x.shape[0], -1, -1),
|
||||
self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
|
||||
x = self.pos_drop(x + self.pos_embed)
|
||||
x = self.blocks(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
x_dist = self.head_dist(x[:, 1])
|
||||
x = self.head(x[:, 0])
|
||||
if self.training and not torch.jit.is_scripting():
|
||||
return x, x_dist
|
||||
else:
|
||||
# during inference, return the average of both classifier predictions
|
||||
return (x + x_dist) / 2
|
||||
|
||||
|
||||
def _create_deit(variant, pretrained=False, distilled=False, **kwargs):
|
||||
if kwargs.get('features_only', None):
|
||||
raise RuntimeError('features_only not implemented for Vision Transformer models.')
|
||||
model_cls = VisionTransformerDistilled if distilled else VisionTransformer
|
||||
model = build_model_with_cfg(
|
||||
model_cls, variant, pretrained,
|
||||
pretrained_filter_fn=checkpoint_filter_fn,
|
||||
**kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def deit_tiny_patch16_224(pretrained=False, **kwargs):
|
||||
""" DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
|
||||
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
|
||||
model = _create_deit('deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def deit_small_patch16_224(pretrained=False, **kwargs):
|
||||
""" DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
|
||||
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
|
||||
model = _create_deit('deit_small_patch16_224', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def deit_base_patch16_224(pretrained=False, **kwargs):
|
||||
""" DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
|
||||
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
||||
model = _create_deit('deit_base_patch16_224', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def deit_base_patch16_384(pretrained=False, **kwargs):
|
||||
""" DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
|
||||
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
||||
model = _create_deit('deit_base_patch16_384', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
|
||||
""" DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
|
||||
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
|
||||
model = _create_deit(
|
||||
'deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
|
||||
""" DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
|
||||
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
|
||||
model = _create_deit(
|
||||
'deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
|
||||
""" DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
|
||||
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
||||
model = _create_deit(
|
||||
'deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
|
||||
""" DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
|
||||
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
||||
model = _create_deit(
|
||||
'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
|
||||
return model
|
Loading…
Reference in new issue