Some ViT cleanup, merge distilled model with main, fixup torchscript support for distilled models

pull/533/head
Ross Wightman 4 years ago
parent 0dfc5a66bb
commit a760a4c3f4

@@ -268,12 +268,16 @@ class HybridEmbed(nn.Module):
 class VisionTransformer(nn.Module):
     """ Vision Transformer

-    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
-        https://arxiv.org/abs/2010.11929
+    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
+        - https://arxiv.org/abs/2010.11929
+
+    Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
+        - https://arxiv.org/abs/2012.12877
     """
     def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
-                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
-                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None):
+                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, distilled=False,
+                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None,
+                 weight_init=''):
         """
         Args:
            img_size (int, tuple): input image size
@@ -287,11 +291,13 @@ class VisionTransformer(nn.Module):
             qkv_bias (bool): enable bias for qkv if True
             qk_scale (float): override default qk scale of head_dim ** -0.5 if set
             representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
+            distilled (bool): model includes a distillation token and head as in DeiT models
             drop_rate (float): dropout rate
             attn_drop_rate (float): attention dropout rate
             drop_path_rate (float): stochastic depth rate
             hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module
             norm_layer: (nn.Module): normalization layer
+            weight_init: (str): weight init scheme
         """
         super().__init__()
         self.num_classes = num_classes
@@ -307,11 +313,13 @@ class VisionTransformer(nn.Module):
         num_patches = self.patch_embed.num_patches

         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
-        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) if distilled else None
+        num_tokens = 2 if distilled else 1
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + num_tokens, embed_dim))
         self.pos_drop = nn.Dropout(p=drop_rate)

         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
-        self.blocks = nn.ModuleList([
+        self.blocks = nn.Sequential(*[
             Block(
                 dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                 drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
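Editorial aside on the nn.ModuleList -> nn.Sequential switch in this hunk: with a Sequential container the explicit Python loop over blocks collapses to a single self.blocks(x) call, which TorchScript handles more gracefully. A toy sketch of the equivalence, using plain nn.Linear layers as stand-ins for the transformer Block modules:

import torch
import torch.nn as nn

# Toy stand-ins for the transformer Blocks; the point is only the container change.
blocks = nn.Sequential(*[nn.Linear(8, 8) for _ in range(3)])

x = torch.randn(2, 8)
y = blocks(x)   # equivalent to: `for blk in blocks: x = blk(x)` with the old nn.ModuleList
print(y.shape)  # torch.Size([2, 8])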
@@ -319,7 +327,7 @@ class VisionTransformer(nn.Module):
         self.norm = norm_layer(embed_dim)

         # Representation layer
-        if representation_size:
+        if representation_size and not distilled:
             self.num_features = representation_size
             self.pre_logits = nn.Sequential(OrderedDict([
                 ('fc', nn.Linear(embed_dim, representation_size)),
@ -328,11 +336,15 @@ class VisionTransformer(nn.Module):
else: else:
self.pre_logits = nn.Identity() self.pre_logits = nn.Identity()
# Classifier head # Classifier head(s)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \
if num_classes > 0 and distilled else nn.Identity()
trunc_normal_(self.pos_embed, std=.02) trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02) trunc_normal_(self.cls_token, std=.02)
if self.dist_token is not None:
trunc_normal_(self.dist_token, std=.02)
self.apply(self._init_weights) self.apply(self._init_weights)
def _init_weights(self, m): def _init_weights(self, m):
@@ -346,91 +358,58 @@ class VisionTransformer(nn.Module):
     @torch.jit.ignore
     def no_weight_decay(self):
-        return {'pos_embed', 'cls_token'}
+        return {'pos_embed', 'cls_token', 'dist_token'}

     def get_classifier(self):
-        return self.head
+        if self.dist_token is None:
+            return self.head
+        else:
+            return self.head, self.head_dist

     def reset_classifier(self, num_classes, global_pool=''):
         self.num_classes = num_classes
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \
+            if num_classes > 0 and self.dist_token is not None else nn.Identity()

     def forward_features(self, x):
-        B = x.shape[0]
-        x = self.patch_embed(x)
-
-        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
-        x = torch.cat((cls_tokens, x), dim=1)
-        x = x + self.pos_embed
-        x = self.pos_drop(x)
-
-        for blk in self.blocks:
-            x = blk(x)
-
-        x = self.norm(x)[:, 0]
-        x = self.pre_logits(x)
-        return x
-
-    def forward(self, x):
-        x = self.forward_features(x)
-        x = self.head(x)
-        return x
-
-
-class DistilledVisionTransformer(VisionTransformer):
-    """ Vision Transformer with distillation token.
-
-    Paper: `Training data-efficient image transformers & distillation through attention` -
-        https://arxiv.org/abs/2012.12877
-
-    This impl of distilled ViT is taken from https://github.com/facebookresearch/deit
-    """
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
-        num_patches = self.patch_embed.num_patches
-        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
-        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
-
-        trunc_normal_(self.dist_token, std=.02)
-        trunc_normal_(self.pos_embed, std=.02)
-        self.head_dist.apply(self._init_weights)
-
-    def forward_features(self, x):
-        B = x.shape[0]
         x = self.patch_embed(x)
-
-        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
-        dist_token = self.dist_token.expand(B, -1, -1)
-        x = torch.cat((cls_tokens, dist_token, x), dim=1)
-
-        x = x + self.pos_embed
-        x = self.pos_drop(x)
-
-        for blk in self.blocks:
-            x = blk(x)
-
+        cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+        if self.dist_token is None:
+            x = torch.cat((cls_token, x), dim=1)
+        else:
+            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
+        x = self.pos_drop(x + self.pos_embed)
+        x = self.blocks(x)
         x = self.norm(x)
-        return x[:, 0], x[:, 1]
+        if self.dist_token is not None:
+            return x[:, 0], x[:, 1]
+        else:
+            return self.pre_logits(x[:, 0])

     def forward(self, x):
-        x, x_dist = self.forward_features(x)
-        x = self.head(x)
-        x_dist = self.head_dist(x_dist)
-        if self.training:
-            return x, x_dist
-        else:
-            # during inference, return the average of both classifier predictions
-            return (x + x_dist) / 2
+        x = self.forward_features(x)
+        if isinstance(x, tuple):
+            x, x_dist = self.head(x[0]), self.head_dist(x[1])
+            if self.training and not torch.jit.is_scripting():
+                # during inference, return the average of both classifier predictions
+                return x, x_dist
+            else:
+                return (x + x_dist) / 2
+        else:
+            x = self.head(x)
+        return x


-def resize_pos_embed(posemb, posemb_new):
+def resize_pos_embed(posemb, posemb_new, token='class'):
     # Rescale the grid of position embeddings when loading from state_dict. Adapted from
     # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
     _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
     ntok_new = posemb_new.shape[1]
-    if True:
-        posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
+    if token:
+        assert token in ('class', 'distill')
+        token_idx = 2 if token == 'distill' else 1
+        posemb_tok, posemb_grid = posemb[:, :token_idx], posemb[0, token_idx:]
         ntok_new -= 1
     else:
         posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
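For context on the token='class' / 'distill' handling just above: the prefix token embeddings (class token, plus the distillation token when present) are carried over unchanged, and only the patch-grid portion of pos_embed gets resized. A rough sketch of that split-and-resize idea, with hypothetical names (resize_grid_posemb_sketch, num_prefix_tokens) and bilinear interpolation assumed for the grid resize:

import math
import torch
import torch.nn.functional as F

def resize_grid_posemb_sketch(posemb, new_grid_len, num_prefix_tokens=1):
    # Split off the prefix token embeddings, then bilinearly resize the patch grid.
    posemb_tok, posemb_grid = posemb[:, :num_prefix_tokens], posemb[0, num_prefix_tokens:]
    gs_old = int(math.sqrt(posemb_grid.shape[0]))
    gs_new = int(math.sqrt(new_grid_len))
    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bilinear', align_corners=False)
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1)
    return torch.cat([posemb_tok, posemb_grid], dim=1)

# e.g. a distilled 224px checkpoint (14x14 grid + 2 tokens) adapted to a 384px model (24x24 grid)
pe = torch.zeros(1, 2 + 14 * 14, 768)
print(resize_grid_posemb_sketch(pe, 24 * 24, num_prefix_tokens=2).shape)  # torch.Size([1, 578, 768])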
@@ -457,12 +436,12 @@ def checkpoint_filter_fn(state_dict, model):
             v = v.reshape(O, -1, H, W)
         elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
             # To resize pos embedding when using model at different size from pretrained weights
-            v = resize_pos_embed(v, model.pos_embed)
+            v = resize_pos_embed(v, model.pos_embed, token='distill' if model.dist_token is not None else 'class')
         out_dict[k] = v
     return out_dict


-def _create_vision_transformer(variant, pretrained=False, distilled=False, **kwargs):
+def _create_vision_transformer(variant, pretrained=False, **kwargs):
     default_cfg = deepcopy(default_cfgs[variant])
     overlay_external_default_cfg(default_cfg, kwargs)
     default_num_classes = default_cfg['num_classes']
@@ -480,9 +459,8 @@ def _create_vision_transformer(variant, pretrained=False, distilled=False, **kwargs):
     if kwargs.get('features_only', None):
         raise RuntimeError('features_only not implemented for Vision Transformer models.')

-    model_cls = DistilledVisionTransformer if distilled else VisionTransformer
     model = build_model_with_cfg(
-        model_cls, variant, pretrained,
+        VisionTransformer, variant, pretrained,
         default_cfg=default_cfg,
         img_size=img_size,
         num_classes=num_classes,
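To illustrate what the merge means for callers, a rough usage sketch (assuming the deit_base_distilled_patch16_224 registry entry, which presumably now builds a plain VisionTransformer with distilled=True): in training mode a distilled model returns a (class_logits, distill_logits) tuple, in eval mode the two head outputs are averaged into one tensor, and per the commit message the not torch.jit.is_scripting() guard is there so a scripted model takes the averaging path as well.

import torch
import timm

# Assumes the deit_base_distilled_patch16_224 entry point; random weights are fine for the demo.
model = timm.create_model('deit_base_distilled_patch16_224', pretrained=False)
x = torch.randn(2, 3, 224, 224)

model.train()
out = model(x)
print(type(out))   # tuple: (class head logits, distillation head logits)

model.eval()
out = model(x)
print(out.shape)   # torch.Size([2, 1000]) -- the two head outputs averaged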
