""" DeiT - Data-efficient Image Transformers DeiT model defs and weights from https://github.com/facebookresearch/deit, original copyright below paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 Modifications copyright 2021, Ross Wightman """ # Copyright (c) 2015-present, Facebook, Inc. # All rights reserved. import torch from torch import nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.vision_transformer import VisionTransformer, trunc_normal_, checkpoint_filter_fn from .helpers import build_model_with_cfg from .registry import register_model def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'first_conv': 'patch_embed.proj', 'classifier': 'head', **kwargs } default_cfgs = { # deit models (FB weights) 'deit_tiny_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'), 'deit_small_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'), 'deit_base_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth'), 'deit_base_patch16_384': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth', input_size=(3, 384, 384), crop_pct=1.0), 'deit_tiny_distilled_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth', classifier=('head', 'head_dist')), 'deit_small_distilled_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth', classifier=('head', 'head_dist')), 'deit_base_distilled_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', classifier=('head', 'head_dist')), 'deit_base_distilled_patch16_384': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth', input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')), } class VisionTransformerDistilled(VisionTransformer): """ Vision Transformer w/ Distillation Token and Head Distillation token & head support for `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 """ def __init__(self, *args, **kwargs): weight_init = kwargs.pop('weight_init', '') super().__init__(*args, **kwargs, weight_init='skip') self.num_tokens = 2 self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + self.num_tokens, self.embed_dim)) self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() self.init_weights(weight_init) def init_weights(self, mode=''): trunc_normal_(self.dist_token, std=.02) super().init_weights(mode=mode) def get_classifier(self): return self.head, self.head_dist def reset_classifier(self, num_classes, global_pool=''): self.num_classes = num_classes self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x) -> torch.Tensor: x = self.patch_embed(x) x = torch.cat(( self.cls_token.expand(x.shape[0], -1, -1), self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) x = self.pos_drop(x + self.pos_embed) x = self.blocks(x) x = self.norm(x) return x def forward(self, x): x = self.forward_features(x) x_dist = self.head_dist(x[:, 1]) x = self.head(x[:, 0]) if self.training and not torch.jit.is_scripting(): return x, x_dist else: # during inference, return the average of both classifier predictions return (x + x_dist) / 2 def _create_deit(variant, pretrained=False, distilled=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') model_cls = VisionTransformerDistilled if distilled else VisionTransformer model = build_model_with_cfg( model_cls, variant, pretrained, pretrained_filter_fn=checkpoint_filter_fn, **kwargs) return model @register_model def deit_tiny_patch16_224(pretrained=False, **kwargs): """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). ImageNet-1k weights from https://github.com/facebookresearch/deit. """ model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) model = _create_deit('deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) return model @register_model def deit_small_patch16_224(pretrained=False, **kwargs): """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). ImageNet-1k weights from https://github.com/facebookresearch/deit. """ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) model = _create_deit('deit_small_patch16_224', pretrained=pretrained, **model_kwargs) return model @register_model def deit_base_patch16_224(pretrained=False, **kwargs): """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). ImageNet-1k weights from https://github.com/facebookresearch/deit. """ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_deit('deit_base_patch16_224', pretrained=pretrained, **model_kwargs) return model @register_model def deit_base_patch16_384(pretrained=False, **kwargs): """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877). ImageNet-1k weights from https://github.com/facebookresearch/deit. """ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_deit('deit_base_patch16_384', pretrained=pretrained, **model_kwargs) return model @register_model def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs): """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). ImageNet-1k weights from https://github.com/facebookresearch/deit. """ model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) model = _create_deit( 'deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) return model @register_model def deit_small_distilled_patch16_224(pretrained=False, **kwargs): """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). ImageNet-1k weights from https://github.com/facebookresearch/deit. """ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) model = _create_deit( 'deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) return model @register_model def deit_base_distilled_patch16_224(pretrained=False, **kwargs): """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). ImageNet-1k weights from https://github.com/facebookresearch/deit. """ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_deit( 'deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) return model @register_model def deit_base_distilled_patch16_384(pretrained=False, **kwargs): """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877). ImageNet-1k weights from https://github.com/facebookresearch/deit. """ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_deit( 'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs) return model