@ -1,7 +1,8 @@
""" MobileViT
Paper :
` MobileViT : Light - weight , General - purpose , and Mobile - friendly Vision Transformer ` - https : / / arxiv . org / abs / 2110.02178
V1 : ` MobileViT : Light - weight , General - purpose , and Mobile - friendly Vision Transformer ` - https : / / arxiv . org / abs / 2110.02178
V2 : ` Separable Self - attention for Mobile Vision Transformers ` - https : / / arxiv . org / abs / 2206.02680
MobileVitBlock and checkpoints adapted from https : / / github . com / apple / ml - cvnets ( original copyright below )
License : https : / / github . com / apple / ml - cvnets / blob / main / LICENSE ( Apple open source )
@ -13,7 +14,7 @@ Rest of code, ByobNet, and Transformer block hacked together by / Copyright 2022
# Copyright (C) 2020 Apple Inc. All Rights Reserved.
#
import math
from typing import Union , Callable , Dict , Tuple , Optional
from typing import Union , Callable , Dict , Tuple , Optional , Sequence
import torch
from torch import nn
@ -21,7 +22,7 @@ import torch.nn.functional as F
from . byobnet import register_block , ByoBlockCfg , ByoModelCfg , ByobNet , LayerFn , num_groups
from . fx_features import register_notrace_module
from . layers import to_2tuple , make_divisible
from . layers import to_2tuple , make_divisible , LayerNorm2d , GroupNorm1 , ConvMlp , DropPath
from . vision_transformer import Block as TransformerBlock
from . helpers import build_model_with_cfg
from . registry import register_model
@ -48,6 +49,48 @@ default_cfgs = {
' mobilevit_s ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevit_s-38a5a959.pth ' ) ,
' semobilevit_s ' : _cfg ( ) ,
' mobilevitv2_050 ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_050-49951ee2.pth ' ,
crop_pct = 0.888 ) ,
' mobilevitv2_075 ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_075-b5556ef6.pth ' ,
crop_pct = 0.888 ) ,
' mobilevitv2_100 ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_100-e464ef3b.pth ' ,
crop_pct = 0.888 ) ,
' mobilevitv2_125 ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_125-0ae35027.pth ' ,
crop_pct = 0.888 ) ,
' mobilevitv2_150 ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_150-737c5019.pth ' ,
crop_pct = 0.888 ) ,
' mobilevitv2_175 ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_175-16462ee2.pth ' ,
crop_pct = 0.888 ) ,
' mobilevitv2_200 ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_200-b3422f67.pth ' ,
crop_pct = 0.888 ) ,
' mobilevitv2_150_in22ft1k ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_150_in22ft1k-0b555d7b.pth ' ,
crop_pct = 0.888 ) ,
' mobilevitv2_175_in22ft1k ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_175_in22ft1k-4117fa1f.pth ' ,
crop_pct = 0.888 ) ,
' mobilevitv2_200_in22ft1k ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_200_in22ft1k-1d7c8927.pth ' ,
crop_pct = 0.888 ) ,
' mobilevitv2_150_384_in22ft1k ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_150_384_in22ft1k-9e142854.pth ' ,
input_size = ( 3 , 384 , 384 ) , pool_size = ( 12 , 12 ) , crop_pct = 1.0 ) ,
' mobilevitv2_175_384_in22ft1k ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_175_384_in22ft1k-059cbe56.pth ' ,
input_size = ( 3 , 384 , 384 ) , pool_size = ( 12 , 12 ) , crop_pct = 1.0 ) ,
' mobilevitv2_200_384_in22ft1k ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_200_384_in22ft1k-32c87503.pth ' ,
input_size = ( 3 , 384 , 384 ) , pool_size = ( 12 , 12 ) , crop_pct = 1.0 ) ,
}
@ -72,6 +115,40 @@ def _mobilevit_block(d, c, s, transformer_dim, transformer_depth, patch_size=4,
)
def _mobilevitv2_block ( d , c , s , transformer_depth , patch_size = 2 , br = 2.0 , transformer_br = 0.5 ) :
# inverted residual + mobilevit blocks as per MobileViT network
return (
_inverted_residual_block ( d = d , c = c , s = s , br = br ) ,
ByoBlockCfg (
type = ' mobilevit2 ' , d = 1 , c = c , s = 1 , br = transformer_br , gs = 1 ,
block_kwargs = dict (
transformer_depth = transformer_depth ,
patch_size = patch_size )
)
)
def _mobilevitv2_cfg ( multiplier = 1.0 ) :
chs = ( 64 , 128 , 256 , 384 , 512 )
if multiplier != 1.0 :
chs = tuple ( [ int ( c * multiplier ) for c in chs ] )
cfg = ByoModelCfg (
blocks = (
_inverted_residual_block ( d = 1 , c = chs [ 0 ] , s = 1 , br = 2.0 ) ,
_inverted_residual_block ( d = 2 , c = chs [ 1 ] , s = 2 , br = 2.0 ) ,
_mobilevitv2_block ( d = 1 , c = chs [ 2 ] , s = 2 , transformer_depth = 2 ) ,
_mobilevitv2_block ( d = 1 , c = chs [ 3 ] , s = 2 , transformer_depth = 4 ) ,
_mobilevitv2_block ( d = 1 , c = chs [ 4 ] , s = 2 , transformer_depth = 3 ) ,
) ,
stem_chs = int ( 32 * multiplier ) ,
stem_type = ' 3x3 ' ,
stem_pool = ' ' ,
downsample = ' ' ,
act_layer = ' silu ' ,
)
return cfg
model_cfgs = dict (
mobilevit_xxs = ByoModelCfg (
blocks = (
@ -137,11 +214,19 @@ model_cfgs = dict(
attn_kwargs = dict ( rd_ratio = 1 / 8 ) ,
num_features = 640 ,
) ,
mobilevitv2_050 = _mobilevitv2_cfg ( .50 ) ,
mobilevitv2_075 = _mobilevitv2_cfg ( .75 ) ,
mobilevitv2_125 = _mobilevitv2_cfg ( 1.25 ) ,
mobilevitv2_100 = _mobilevitv2_cfg ( 1.0 ) ,
mobilevitv2_150 = _mobilevitv2_cfg ( 1.5 ) ,
mobilevitv2_175 = _mobilevitv2_cfg ( 1.75 ) ,
mobilevitv2_200 = _mobilevitv2_cfg ( 2.0 ) ,
)
@register_notrace_module
class MobileViTBlock ( nn . Module ) :
class MobileVi t Block( nn . Module ) :
""" MobileViT block
Paper : https : / / arxiv . org / abs / 2110.02178 ? context = cs . LG
"""
@ -165,9 +250,9 @@ class MobileViTBlock(nn.Module):
drop_path_rate : float = 0. ,
layers : LayerFn = None ,
transformer_norm_layer : Callable = nn . LayerNorm ,
downsample : str = ' '
* * kwargs , # eat unused args
) :
super ( MobileVi T Block, self ) . __init__ ( )
super ( MobileVi t Block, self ) . __init__ ( )
layers = layers or LayerFn ( )
groups = num_groups ( group_size , in_chs )
@ -241,7 +326,270 @@ class MobileViTBlock(nn.Module):
return x
register_block ( ' mobilevit ' , MobileViTBlock )
class LinearSelfAttention ( nn . Module ) :
"""
This layer applies a self - attention with linear complexity , as described in ` https : / / arxiv . org / abs / 2206.02680 `
This layer can be used for self - as well as cross - attention .
Args :
embed_dim ( int ) : : math : ` C ` from an expected input of size : math : ` ( N , C , H , W ) `
attn_drop ( float ) : Dropout value for context scores . Default : 0.0
bias ( bool ) : Use bias in learnable layers . Default : True
Shape :
- Input : : math : ` ( N , C , P , N ) ` where : math : ` N ` is the batch size , : math : ` C ` is the input channels ,
: math : ` P ` is the number of pixels in the patch , and : math : ` N ` is the number of patches
- Output : same as the input
. . note : :
For MobileViTv2 , we unfold the feature map [ B , C , H , W ] into [ B , C , P , N ] where P is the number of pixels
in a patch and N is the number of patches . Because channel is the first dimension in this unfolded tensor ,
we use point - wise convolution ( instead of a linear layer ) . This avoids a transpose operation ( which may be
expensive on resource - constrained devices ) that may be required to convert the unfolded tensor from
channel - first to channel - last format in case of a linear layer .
"""
def __init__ (
self ,
embed_dim : int ,
attn_drop : float = 0.0 ,
proj_drop : float = 0.0 ,
bias : bool = True ,
) - > None :
super ( ) . __init__ ( )
self . embed_dim = embed_dim
self . qkv_proj = nn . Conv2d (
in_channels = embed_dim ,
out_channels = 1 + ( 2 * embed_dim ) ,
bias = bias ,
kernel_size = 1 ,
)
self . attn_drop = nn . Dropout ( attn_drop )
self . out_proj = nn . Conv2d (
in_channels = embed_dim ,
out_channels = embed_dim ,
bias = bias ,
kernel_size = 1 ,
)
self . out_drop = nn . Dropout ( proj_drop )
def _forward_self_attn ( self , x : torch . Tensor ) - > torch . Tensor :
# [B, C, P, N] --> [B, h + 2d, P, N]
qkv = self . qkv_proj ( x )
# Project x into query, key and value
# Query --> [B, 1, P, N]
# value, key --> [B, d, P, N]
query , key , value = qkv . split ( [ 1 , self . embed_dim , self . embed_dim ] , dim = 1 )
# apply softmax along N dimension
context_scores = F . softmax ( query , dim = - 1 )
context_scores = self . attn_drop ( context_scores )
# Compute context vector
# [B, d, P, N] x [B, 1, P, N] -> [B, d, P, N] --> [B, d, P, 1]
context_vector = ( key * context_scores ) . sum ( dim = - 1 , keepdim = True )
# combine context vector with values
# [B, d, P, N] * [B, d, P, 1] --> [B, d, P, N]
out = F . relu ( value ) * context_vector . expand_as ( value )
out = self . out_proj ( out )
out = self . out_drop ( out )
return out
@torch.jit.ignore ( )
def _forward_cross_attn ( self , x : torch . Tensor , x_prev : Optional [ torch . Tensor ] = None ) - > torch . Tensor :
# x --> [B, C, P, N]
# x_prev = [B, C, P, M]
batch_size , in_dim , kv_patch_area , kv_num_patches = x . shape
q_patch_area , q_num_patches = x . shape [ - 2 : ]
assert (
kv_patch_area == q_patch_area
) , " The number of pixels in a patch for query and key_value should be the same "
# compute query, key, and value
# [B, C, P, M] --> [B, 1 + d, P, M]
qk = F . conv2d (
x_prev ,
weight = self . qkv_proj . weight [ : self . embed_dim + 1 ] ,
bias = self . qkv_proj . bias [ : self . embed_dim + 1 ] ,
)
# [B, 1 + d, P, M] --> [B, 1, P, M], [B, d, P, M]
query , key = qk . split ( [ 1 , self . embed_dim ] , dim = 1 )
# [B, C, P, N] --> [B, d, P, N]
value = F . conv2d (
x ,
weight = self . qkv_proj . weight [ self . embed_dim + 1 ] ,
bias = self . qkv_proj . bias [ self . embed_dim + 1 ] if self . qkv_proj . bias is not None else None ,
)
# apply softmax along M dimension
context_scores = F . softmax ( query , dim = - 1 )
context_scores = self . attn_drop ( context_scores )
# compute context vector
# [B, d, P, M] * [B, 1, P, M] -> [B, d, P, M] --> [B, d, P, 1]
context_vector = ( key * context_scores ) . sum ( dim = - 1 , keepdim = True )
# combine context vector with values
# [B, d, P, N] * [B, d, P, 1] --> [B, d, P, N]
out = F . relu ( value ) * context_vector . expand_as ( value )
out = self . out_proj ( out )
out = self . out_drop ( out )
return out
def forward ( self , x : torch . Tensor , x_prev : Optional [ torch . Tensor ] = None ) - > torch . Tensor :
if x_prev is None :
return self . _forward_self_attn ( x )
else :
return self . _forward_cross_attn ( x , x_prev = x_prev )
class LinearTransformerBlock ( nn . Module ) :
"""
This class defines the pre - norm transformer encoder with linear self - attention in ` MobileViTv2 paper < > ` _
Args :
embed_dim ( int ) : : math : ` C_ { in } ` from an expected input of size : math : ` ( B , C_ { in } , P , N ) `
mlp_ratio ( float ) : Inner dimension ratio of the FFN relative to embed_dim
drop ( float ) : Dropout rate . Default : 0.0
attn_drop ( float ) : Dropout rate for attention in multi - head attention . Default : 0.0
drop_path ( float ) : Stochastic depth rate Default : 0.0
norm_layer ( Callable ) : Normalization layer . Default : layer_norm_2d
Shape :
- Input : : math : ` ( B , C_ { in } , P , N ) ` where : math : ` B ` is batch size , : math : ` C_ { in } ` is input embedding dim ,
: math : ` P ` is number of pixels in a patch , and : math : ` N ` is number of patches ,
- Output : same shape as the input
"""
def __init__ (
self ,
embed_dim : int ,
mlp_ratio : float = 2.0 ,
drop : float = 0.0 ,
attn_drop : float = 0.0 ,
drop_path : float = 0.0 ,
act_layer = None ,
norm_layer = None ,
) - > None :
super ( ) . __init__ ( )
act_layer = act_layer or nn . SiLU
norm_layer = norm_layer or GroupNorm1
self . norm1 = norm_layer ( embed_dim )
self . attn = LinearSelfAttention ( embed_dim = embed_dim , attn_drop = attn_drop , proj_drop = drop )
self . drop_path1 = DropPath ( drop_path )
self . norm2 = norm_layer ( embed_dim )
self . mlp = ConvMlp (
in_features = embed_dim ,
hidden_features = int ( embed_dim * mlp_ratio ) ,
act_layer = act_layer ,
drop = drop )
self . drop_path2 = DropPath ( drop_path )
def forward ( self , x : torch . Tensor , x_prev : Optional [ torch . Tensor ] = None ) - > torch . Tensor :
if x_prev is None :
# self-attention
x = x + self . drop_path1 ( self . attn ( self . norm1 ( x ) ) )
else :
# cross-attention
res = x
x = self . norm1 ( x ) # norm
x = self . attn ( x , x_prev ) # attn
x = self . drop_path1 ( x ) + res # residual
# Feed forward network
x = x + self . drop_path2 ( self . mlp ( self . norm2 ( x ) ) )
return x
@register_notrace_module
class MobileVitV2Block ( nn . Module ) :
"""
This class defines the ` MobileViTv2 block < > ` _
"""
def __init__ (
self ,
in_chs : int ,
out_chs : Optional [ int ] = None ,
kernel_size : int = 3 ,
bottle_ratio : float = 1.0 ,
group_size : Optional [ int ] = 1 ,
dilation : Tuple [ int , int ] = ( 1 , 1 ) ,
mlp_ratio : float = 2.0 ,
transformer_dim : Optional [ int ] = None ,
transformer_depth : int = 2 ,
patch_size : int = 8 ,
attn_drop : float = 0. ,
drop : int = 0. ,
drop_path_rate : float = 0. ,
layers : LayerFn = None ,
transformer_norm_layer : Callable = GroupNorm1 ,
* * kwargs , # eat unused args
) :
super ( MobileVitV2Block , self ) . __init__ ( )
layers = layers or LayerFn ( )
groups = num_groups ( group_size , in_chs )
out_chs = out_chs or in_chs
transformer_dim = transformer_dim or make_divisible ( bottle_ratio * in_chs )
self . conv_kxk = layers . conv_norm_act (
in_chs , in_chs , kernel_size = kernel_size ,
stride = 1 , groups = groups , dilation = dilation [ 0 ] )
self . conv_1x1 = nn . Conv2d ( in_chs , transformer_dim , kernel_size = 1 , bias = False )
self . transformer = nn . Sequential ( * [
LinearTransformerBlock (
transformer_dim ,
mlp_ratio = mlp_ratio ,
attn_drop = attn_drop ,
drop = drop ,
drop_path = drop_path_rate ,
act_layer = layers . act ,
norm_layer = transformer_norm_layer
)
for _ in range ( transformer_depth )
] )
self . norm = transformer_norm_layer ( transformer_dim )
self . conv_proj = layers . conv_norm_act ( transformer_dim , out_chs , kernel_size = 1 , stride = 1 , apply_act = False )
self . patch_size = to_2tuple ( patch_size )
self . patch_area = self . patch_size [ 0 ] * self . patch_size [ 1 ]
def forward ( self , x : torch . Tensor ) - > torch . Tensor :
B , C , H , W = x . shape
patch_h , patch_w = self . patch_size
new_h , new_w = math . ceil ( H / patch_h ) * patch_h , math . ceil ( W / patch_w ) * patch_w
num_patch_h , num_patch_w = new_h / / patch_h , new_w / / patch_w # n_h, n_w
num_patches = num_patch_h * num_patch_w # N
if new_h != H or new_w != W :
x = F . interpolate ( x , size = ( new_h , new_w ) , mode = " bilinear " , align_corners = True )
# Local representation
x = self . conv_kxk ( x )
x = self . conv_1x1 ( x )
# Unfold (feature map -> patches), [B, C, H, W] -> [B, C, P, N]
C = x . shape [ 1 ]
x = x . reshape ( B , C , num_patch_h , patch_h , num_patch_w , patch_w ) . permute ( 0 , 1 , 3 , 5 , 2 , 4 )
x = x . reshape ( B , C , - 1 , num_patches )
# Global representations
x = self . transformer ( x )
x = self . norm ( x )
# Fold (patches -> feature map), [B, C, P, N] --> [B, C, H, W]
x = x . reshape ( B , C , patch_h , patch_w , num_patch_h , num_patch_w ) . permute ( 0 , 1 , 4 , 2 , 5 , 3 )
x = x . reshape ( B , C , num_patch_h * patch_h , num_patch_w * patch_w )
x = self . conv_proj ( x )
return x
register_block ( ' mobilevit ' , MobileVitBlock )
register_block ( ' mobilevit2 ' , MobileVitV2Block )
def _create_mobilevit ( variant , cfg_variant = None , pretrained = False , * * kwargs ) :
@ -252,6 +600,14 @@ def _create_mobilevit(variant, cfg_variant=None, pretrained=False, **kwargs):
* * kwargs )
def _create_mobilevit2 ( variant , cfg_variant = None , pretrained = False , * * kwargs ) :
return build_model_with_cfg (
ByobNet , variant , pretrained ,
model_cfg = model_cfgs [ variant ] if not cfg_variant else model_cfgs [ cfg_variant ] ,
feature_cfg = dict ( flatten_sequential = True ) ,
* * kwargs )
@register_model
def mobilevit_xxs ( pretrained = False , * * kwargs ) :
return _create_mobilevit ( ' mobilevit_xxs ' , pretrained = pretrained , * * kwargs )
@ -270,3 +626,74 @@ def mobilevit_s(pretrained=False, **kwargs):
@register_model
def semobilevit_s ( pretrained = False , * * kwargs ) :
return _create_mobilevit ( ' semobilevit_s ' , pretrained = pretrained , * * kwargs )
@register_model
def mobilevitv2_050 ( pretrained = False , * * kwargs ) :
return _create_mobilevit ( ' mobilevitv2_050 ' , pretrained = pretrained , * * kwargs )
@register_model
def mobilevitv2_075 ( pretrained = False , * * kwargs ) :
return _create_mobilevit ( ' mobilevitv2_075 ' , pretrained = pretrained , * * kwargs )
@register_model
def mobilevitv2_100 ( pretrained = False , * * kwargs ) :
return _create_mobilevit ( ' mobilevitv2_100 ' , pretrained = pretrained , * * kwargs )
@register_model
def mobilevitv2_125 ( pretrained = False , * * kwargs ) :
return _create_mobilevit ( ' mobilevitv2_125 ' , pretrained = pretrained , * * kwargs )
@register_model
def mobilevitv2_150 ( pretrained = False , * * kwargs ) :
return _create_mobilevit ( ' mobilevitv2_150 ' , pretrained = pretrained , * * kwargs )
@register_model
def mobilevitv2_175 ( pretrained = False , * * kwargs ) :
return _create_mobilevit ( ' mobilevitv2_175 ' , pretrained = pretrained , * * kwargs )
@register_model
def mobilevitv2_200 ( pretrained = False , * * kwargs ) :
return _create_mobilevit ( ' mobilevitv2_200 ' , pretrained = pretrained , * * kwargs )
@register_model
def mobilevitv2_150_in22ft1k ( pretrained = False , * * kwargs ) :
return _create_mobilevit (
' mobilevitv2_150_in22ft1k ' , cfg_variant = ' mobilevitv2_150 ' , pretrained = pretrained , * * kwargs )
@register_model
def mobilevitv2_175_in22ft1k ( pretrained = False , * * kwargs ) :
return _create_mobilevit (
' mobilevitv2_175_in22ft1k ' , cfg_variant = ' mobilevitv2_175 ' , pretrained = pretrained , * * kwargs )
@register_model
def mobilevitv2_200_in22ft1k ( pretrained = False , * * kwargs ) :
return _create_mobilevit (
' mobilevitv2_200_in22ft1k ' , cfg_variant = ' mobilevitv2_200 ' , pretrained = pretrained , * * kwargs )
@register_model
def mobilevitv2_150_384_in22ft1k ( pretrained = False , * * kwargs ) :
return _create_mobilevit (
' mobilevitv2_150_384_in22ft1k ' , cfg_variant = ' mobilevitv2_150 ' , pretrained = pretrained , * * kwargs )
@register_model
def mobilevitv2_175_384_in22ft1k ( pretrained = False , * * kwargs ) :
return _create_mobilevit (
' mobilevitv2_175_384_in22ft1k ' , cfg_variant = ' mobilevitv2_175 ' , pretrained = pretrained , * * kwargs )
@register_model
def mobilevitv2_200_384_in22ft1k ( pretrained = False , * * kwargs ) :
return _create_mobilevit (
' mobilevitv2_200_384_in22ft1k ' , cfg_variant = ' mobilevitv2_200 ' , pretrained = pretrained , * * kwargs )