@ -16,24 +16,19 @@ Copyright 2021 Alexander Soare
"""
import collections . abc
from functools import partial
import math
import logging
import math
from functools import partial
import numpy as np
import torch
from torch import nn
import torch . nn . functional as F
from torch import nn
from timm . data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD
from . helpers import build_model_with_cfg , named_apply
from . layers import PatchEmbed , Mlp , DropPath , create_classifier , trunc_normal_
from . layers . helpers import to_ntuple
from . layers . create_conv2d import create_conv2d
from . layers . pool2d_same import create_pool2d
from . vision_transformer import Block
from . layers import create_conv2d , create_pool2d , to_ntuple
from . registry import register_model
from . helpers import build_model_with_cfg , named_apply
from . vision_transformer import resize_pos_embed
_logger = logging . getLogger ( __name__ )
@ -54,9 +49,12 @@ default_cfgs = {
' nest_base ' : _cfg ( ) ,
' nest_small ' : _cfg ( ) ,
' nest_tiny ' : _cfg ( ) ,
' jx_nest_base ' : _cfg ( url = ' https://www.todo-this-is-a-placeholder.com/jx_nest_base.pth ' ) , # TODO
' jx_nest_small ' : _cfg ( url = ' https://www.todo-this-is-a-placeholder.com/jx_nest_small.pth ' ) , # TODO
' jx_nest_tiny ' : _cfg ( url = ' https://www.todo-this-is-a-placeholder.com/jx_nest_tiny.pth ' ) , # TODO
' jx_nest_base ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_base-8bc41011.pth ' ) ,
' jx_nest_small ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_small-422eaded.pth ' ) ,
' jx_nest_tiny ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_tiny-e3428fb9.pth ' ) ,
}
@ -96,7 +94,7 @@ class Attention(nn.Module):
return x # (B, T, N, C)
class TransformerLayer ( Block ) :
class TransformerLayer ( nn. Module ) :
"""
This is much like ` . vision_transformer . Block ` but :
- Called TransformerLayer here to allow for " block " as defined in the paper ( " non-overlapping image blocks " )
@ -104,8 +102,7 @@ class TransformerLayer(Block):
"""
def __init__ ( self , dim , num_heads , mlp_ratio = 4. , qkv_bias = False , drop = 0. , attn_drop = 0. , drop_path = 0. ,
act_layer = nn . GELU , norm_layer = nn . LayerNorm ) :
super ( ) . __init__ ( dim , num_heads , mlp_ratio = 4. , qkv_bias = False , drop = 0. , attn_drop = 0. , drop_path = 0. ,
act_layer = nn . GELU , norm_layer = nn . LayerNorm )
super ( ) . __init__ ( )
self . norm1 = norm_layer ( dim )
self . attn = Attention ( dim , num_heads = num_heads , qkv_bias = qkv_bias , attn_drop = attn_drop , proj_drop = drop )
self . drop_path = DropPath ( drop_path ) if drop_path > 0. else nn . Identity ( )
@ -120,7 +117,7 @@ class TransformerLayer(Block):
return x
class BlockAggregation ( nn . Module ) :
class ConvPool ( nn . Module ) :
def __init__ ( self , in_channels , out_channels , norm_layer , pad_type = ' ' ) :
super ( ) . __init__ ( )
self . conv = create_conv2d ( in_channels , out_channels , kernel_size = 3 , padding = pad_type , bias = True )
@ -152,8 +149,7 @@ def blockify(x, block_size: int):
grid_height = H / / block_size
grid_width = W / / block_size
x = x . reshape ( B , grid_height , block_size , grid_width , block_size , C )
x = x . permute ( 0 , 1 , 3 , 2 , 4 , 5 )
x = x . reshape ( B , grid_height * grid_width , - 1 , C )
x = x . transpose ( 2 , 3 ) . reshape ( B , grid_height * grid_width , - 1 , C )
return x # (B, T, N, C)
@ -165,21 +161,28 @@ def deblockify(x, block_size: int):
"""
B , T , _ , C = x . shape
grid_size = int ( math . sqrt ( T ) )
x = x . reshape ( B , grid_size , grid_size , block_size , block_size , C )
x = x . permute ( 0 , 1 , 3 , 2 , 4 , 5 )
height = width = grid_size * block_size
x = x . reshape ( B , height , width , C )
x = x . reshape ( B , grid_size , grid_size , block_size , block_size , C )
x = x . transpose ( 2 , 3 ) . reshape ( B , height , width , C )
return x # (B, H, W, C)
class NestLevel ( nn . Module ) :
""" Single hierarchical level of a Nested Transformer
"""
def __init__ ( self , num_blocks , block_size , seq_length , num_heads , depth , embed_dim , mlp_ratio = 4. , qkv_bias = True ,
drop_rate = 0. , attn_drop_rate = 0. , drop_path_rates = [ ] , norm_layer = None , act_layer = None ) :
def __init__ (
self , num_blocks , block_size , seq_length , num_heads , depth , embed_dim , prev_embed_dim = None ,
mlp_ratio = 4. , qkv_bias = True , drop_rate = 0. , attn_drop_rate = 0. , drop_path_rates = [ ] ,
norm_layer = None , act_layer = None , pad_type = ' ' ) :
super ( ) . __init__ ( )
self . block_size = block_size
self . pos_embed = nn . Parameter ( torch . zeros ( 1 , num_blocks , seq_length , embed_dim ) )
if prev_embed_dim is not None :
self . pool = ConvPool ( prev_embed_dim , embed_dim , norm_layer = norm_layer , pad_type = pad_type )
else :
self . pool = nn . Identity ( )
# Transformer encoder
if len ( drop_path_rates ) :
assert len ( drop_path_rates ) == depth , ' Must provide as many drop path rates as there are transformer layers '
@ -194,15 +197,14 @@ class NestLevel(nn.Module):
"""
expects x as ( B , C , H , W )
"""
# Switch to channels last for transformer
x = x . permute ( 0 , 2 , 3 , 1 ) # (B, H', W', C)
x = self . pool ( x )
x = x . permute ( 0 , 2 , 3 , 1 ) # (B, H', W', C) , switch to channels last for transformer
x = blockify ( x , self . block_size ) # (B, T, N, C')
x = x + self . pos_embed
x = self . transformer_encoder ( x ) # (B, T, N, C')
x = deblockify ( x , self . block_size ) # (B, H', W', C')
# Channel-first for block aggregation, and generally to replicate convnet feature map at each stage
x = x . permute ( 0 , 3 , 1 , 2 ) # (B, C, H', W')
return x
return x . permute ( 0 , 3 , 1 , 2 ) # (B, C, H', W')
class Nest ( nn . Module ) :
@ -213,9 +215,9 @@ class Nest(nn.Module):
"""
def __init__ ( self , img_size = 224 , in_chans = 3 , patch_size = 4 , num_levels = 3 , embed_dims = ( 128 , 256 , 512 ) ,
num_heads = ( 4 , 8 , 16 ) , depths = ( 2 , 2 , 20 ) , num_classes = 1000 , mlp_ratio = 4. , qkv_bias = True , pad_type = ' ' ,
drop_rate = 0. , attn_drop_rate = 0. , drop_path_rate = 0.5 , norm_layer = None , act_layer = None , weight_init = ' ' ,
global_pool= ' avg ' ) :
num_heads = ( 4 , 8 , 16 ) , depths = ( 2 , 2 , 20 ) , num_classes = 1000 , mlp_ratio = 4. , qkv_bias = True ,
drop_rate = 0. , attn_drop_rate = 0. , drop_path_rate = 0.5 , norm_layer = None , act_layer = None ,
pad_type= ' ' , weight_init = ' ' , global_pool= ' avg ' ) :
"""
Args :
img_size ( int , tuple ) : input image size
@ -233,6 +235,7 @@ class Nest(nn.Module):
drop_path_rate ( float ) : stochastic depth rate
norm_layer : ( nn . Module ) : normalization layer for transformer layers
act_layer : ( nn . Module ) : activation layer in MLP of transformer layers
pad_type : str : Type of padding to use ' ' for PyTorch symmetric , ' same ' for TF SAME
weight_init : ( str ) : weight init scheme
global_pool : ( str ) : type of pooling operation to apply to final feature map
@ -254,6 +257,7 @@ class Nest(nn.Module):
depths = to_ntuple ( num_levels ) ( depths )
self . num_classes = num_classes
self . num_features = embed_dims [ - 1 ]
self . feature_info = [ ]
norm_layer = norm_layer or partial ( nn . LayerNorm , eps = 1e-6 )
act_layer = act_layer or nn . GELU
self . drop_rate = drop_rate
@ -265,60 +269,54 @@ class Nest(nn.Module):
self . patch_size = patch_size
# Number of blocks at each level
self . num_blocks = 4 * * ( np . arange ( num_levels ) [ : : - 1 ] )
assert ( img_size / / patch_size ) % np . sqrt ( self . num_blocks [ 0 ] ) == 0 , \
self . num_blocks = ( 4 * * torch . arange ( num_levels ) ) . flip ( 0 ) . tolist ( )
assert ( img_size / / patch_size ) % math . sqrt ( self . num_blocks [ 0 ] ) == 0 , \
' First level blocks don \' t fit evenly. Check `img_size`, `patch_size`, and `num_levels` '
# Block edge size in units of patches
# Hint: (img_size // patch_size) gives number of patches along edge of image. sqrt(self.num_blocks[0]) is the
# number of blocks along edge of image
self . block_size = int ( ( img_size / / patch_size ) / / np . sqrt ( self . num_blocks [ 0 ] ) )
self . block_size = int ( ( img_size / / patch_size ) / / math . sqrt ( self . num_blocks [ 0 ] ) )
# Patch embedding
self . patch_embed = PatchEmbed (
img_size = img_size , patch_size = patch_size , in_chans = in_chans , embed_dim = embed_dims [ 0 ] )
img_size = img_size , patch_size = patch_size , in_chans = in_chans , embed_dim = embed_dims [ 0 ] , flatten = False )
self . num_patches = self . patch_embed . num_patches
self . seq_length = self . num_patches / / self . num_blocks [ 0 ]
# Build up each hierarchical level
self . levels = nn . ModuleList ( [ ] )
self . block_aggs = nn . ModuleList ( [ ] )
drop_path_rates = [ x . item ( ) for x in torch . linspace ( 0 , drop_path_rate , sum ( depths ) ) ]
for lix in range ( self . num_levels ) :
dpr = drop_path_rates [ sum ( depths [ : lix ] ) : sum ( depths [ : lix + 1 ] ) ]
self . levels . append ( NestLevel (
self . num_blocks [ lix ] , self . block_size , self . seq_length , num_heads [ lix ] , depths [ lix ] ,
embed_dims [ lix ] , mlp_ratio , qkv_bias , drop_rate , attn_drop_rate , dpr , norm_layer ,
act_layer ) )
if lix < self . num_levels - 1 :
self . block_aggs . append ( BlockAggregation (
embed_dims [ lix ] , embed_dims [ lix + 1 ] , norm_layer , pad_type = pad_type ) )
else :
# Required for zipped iteration over levels and ls_block_agg together
self . block_aggs . append ( nn . Identity ( ) )
levels = [ ]
dp_rates = [ x . tolist ( ) for x in torch . linspace ( 0 , drop_path_rate , sum ( depths ) ) . split ( depths ) ]
prev_dim = None
curr_stride = 4
for i in range ( len ( self . num_blocks ) ) :
dim = embed_dims [ i ]
levels . append ( NestLevel (
self . num_blocks [ i ] , self . block_size , self . seq_length , num_heads [ i ] , depths [ i ] , dim , prev_dim ,
mlp_ratio , qkv_bias , drop_rate , attn_drop_rate , dp_rates [ i ] , norm_layer , act_layer , pad_type = pad_type ) )
self . feature_info + = [ dict ( num_chs = dim , reduction = curr_stride , module = f ' levels. { i } ' ) ]
prev_dim = dim
curr_stride * = 2
self . levels = nn . Sequential ( * levels )
# Final normalization layer
self . norm = norm_layer ( embed_dims [ - 1 ] )
# Classifier
self . global_pool , self . head = create_classifier (
self . num_features , self . num_classes , pool_type = global_pool )
self . global_pool , self . head = create_classifier ( self . num_features , self . num_classes , pool_type = global_pool )
self . init_weights ( weight_init )
def init_weights ( self , mode = ' ' ) :
assert mode in ( ' jax' , ' jax_nlhb ' , ' nlhb' , ' ' )
assert mode in ( ' nlhb' , ' ' )
head_bias = - math . log ( self . num_classes ) if ' nlhb ' in mode else 0.
for level in self . levels :
trunc_normal_ ( level . pos_embed , std = .02 , a = - 2 , b = 2 )
if mode . startswith ( ' jax ' ) :
named_apply ( partial ( _init_nest_weights , head_bias = head_bias , jax_impl = True ) , self )
else :
named_apply ( _init_nest_weights , self )
named_apply ( partial ( _init_nest_weights , head_bias = head_bias ) , self )
@torch.jit.ignore
def no_weight_decay ( self ) :
return { ' pos_embed' }
return { f ' level.{ i } . pos_embed' for i in range ( len ( self . levels ) ) }
def get_classifier ( self ) :
return self . head
@ -333,13 +331,8 @@ class Nest(nn.Module):
"""
B , _ , H , W = x . shape
x = self . patch_embed ( x )
x = x . reshape ( B , H / / self . patch_size , W / / self . patch_size , - 1 ) # (B, H', W', C')
x = x . permute ( 0 , 3 , 1 , 2 )
# NOTE: TorchScript won't let us subscript module lists with integer variables, so we iterate instead
for level , block_agg in zip ( self . levels , self . block_aggs ) :
x = level ( x )
x = block_agg ( x )
# Layer norm done over channel dim only
x = self . levels ( x )
# Layer norm done over channel dim only (to NHWC and back)
x = self . norm ( x . permute ( 0 , 2 , 3 , 1 ) ) . permute ( 0 , 3 , 1 , 2 )
return x
@ -353,22 +346,19 @@ class Nest(nn.Module):
return self . head ( x )
def _init_nest_weights ( module : nn . Module , name : str = ' ' , head_bias : float = 0. , jax_impl : bool = False ):
def _init_nest_weights ( module : nn . Module , name : str = ' ' , head_bias : float = 0. ):
""" NesT weight initialization
Can replicate Jax implementation . Otherwise follows vision_transformer . py
"""
if isinstance ( module , nn . Linear ) :
if name . startswith ( ' head ' ) :
if jax_impl :
trunc_normal_ ( module . weight , std = .02 , a = - 2 , b = 2 )
else :
nn . init . zeros_ ( module . weight )
nn . init . constant_ ( module . bias , head_bias )
else :
trunc_normal_ ( module . weight , std = .02 , a = - 2 , b = 2 )
if module . bias is not None :
nn . init . zeros_ ( module . bias )
elif jax_impl and isinstance ( module , nn . Conv2d ) :
elif isinstance ( module , nn . Conv2d ) :
trunc_normal_ ( module . weight , std = .02 , a = - 2 , b = 2 )
if module . bias is not None :
nn . init . zeros_ ( module . bias )
@ -404,13 +394,11 @@ def checkpoint_filter_fn(state_dict, model):
def _create_nest ( variant , pretrained = False , default_cfg = None , * * kwargs ) :
if kwargs . get ( ' features_only ' , None ) :
raise RuntimeError ( ' features_only not implemented for Vision Transformer models. ' )
default_cfg = default_cfg or default_cfgs [ variant ]
model = build_model_with_cfg (
Nest , variant , pretrained ,
default_cfg = default_cfg ,
feature_cfg = dict ( out_indices = ( 0 , 1 , 2 ) , flatten_sequential = True ) ,
pretrained_filter_fn = checkpoint_filter_fn ,
* * kwargs )
@ -422,7 +410,7 @@ def nest_base(pretrained=False, **kwargs):
""" Nest-B @ 224x224
"""
model_kwargs = dict (
embed_dims = ( 128 , 256 , 512 ) , num_heads = ( 4 , 8 , 16 ) , depths = ( 2 , 2 , 20 ) , drop_path_rate = 0.5 , * * kwargs )
embed_dims = ( 128 , 256 , 512 ) , num_heads = ( 4 , 8 , 16 ) , depths = ( 2 , 2 , 20 ) , * * kwargs )
model = _create_nest ( ' nest_base ' , pretrained = pretrained , * * model_kwargs )
return model
@ -431,8 +419,7 @@ def nest_base(pretrained=False, **kwargs):
def nest_small ( pretrained = False , * * kwargs ) :
""" Nest-S @ 224x224
"""
model_kwargs = dict (
embed_dims = ( 96 , 192 , 384 ) , num_heads = ( 3 , 6 , 12 ) , depths = ( 2 , 2 , 20 ) , drop_path_rate = 0.3 , * * kwargs )
model_kwargs = dict ( embed_dims = ( 96 , 192 , 384 ) , num_heads = ( 3 , 6 , 12 ) , depths = ( 2 , 2 , 20 ) , * * kwargs )
model = _create_nest ( ' nest_small ' , pretrained = pretrained , * * model_kwargs )
return model
@ -441,8 +428,7 @@ def nest_small(pretrained=False, **kwargs):
def nest_tiny ( pretrained = False , * * kwargs ) :
""" Nest-T @ 224x224
"""
model_kwargs = dict (
embed_dims = ( 96 , 192 , 384 ) , num_heads = ( 3 , 6 , 12 ) , depths = ( 2 , 2 , 8 ) , drop_path_rate = 0.2 , * * kwargs )
model_kwargs = dict ( embed_dims = ( 96 , 192 , 384 ) , num_heads = ( 3 , 6 , 12 ) , depths = ( 2 , 2 , 8 ) , * * kwargs )
model = _create_nest ( ' nest_tiny ' , pretrained = pretrained , * * model_kwargs )
return model
@ -452,9 +438,7 @@ def jx_nest_base(pretrained=False, **kwargs):
""" Nest-B @ 224x224, Pretrained weights converted from official Jax impl.
"""
kwargs [ ' pad_type ' ] = ' same '
kwargs [ ' weight_init ' ] = ' jax '
model_kwargs = dict (
embed_dims = ( 128 , 256 , 512 ) , num_heads = ( 4 , 8 , 16 ) , depths = ( 2 , 2 , 20 ) , drop_path_rate = 0.5 , * * kwargs )
model_kwargs = dict ( embed_dims = ( 128 , 256 , 512 ) , num_heads = ( 4 , 8 , 16 ) , depths = ( 2 , 2 , 20 ) , * * kwargs )
model = _create_nest ( ' jx_nest_base ' , pretrained = pretrained , * * model_kwargs )
return model
@ -464,9 +448,7 @@ def jx_nest_small(pretrained=False, **kwargs):
""" Nest-S @ 224x224, Pretrained weights converted from official Jax impl.
"""
kwargs [ ' pad_type ' ] = ' same '
kwargs [ ' weight_init ' ] = ' jax '
model_kwargs = dict (
embed_dims = ( 96 , 192 , 384 ) , num_heads = ( 3 , 6 , 12 ) , depths = ( 2 , 2 , 20 ) , drop_path_rate = 0.3 , * * kwargs )
model_kwargs = dict ( embed_dims = ( 96 , 192 , 384 ) , num_heads = ( 3 , 6 , 12 ) , depths = ( 2 , 2 , 20 ) , * * kwargs )
model = _create_nest ( ' jx_nest_small ' , pretrained = pretrained , * * model_kwargs )
return model
@ -476,8 +458,6 @@ def jx_nest_tiny(pretrained=False, **kwargs):
""" Nest-T @ 224x224, Pretrained weights converted from official Jax impl.
"""
kwargs [ ' pad_type ' ] = ' same '
kwargs [ ' weight_init ' ] = ' jax '
model_kwargs = dict (
embed_dims = ( 96 , 192 , 384 ) , num_heads = ( 3 , 6 , 12 ) , depths = ( 2 , 2 , 8 ) , drop_path_rate = 0.2 , * * kwargs )
model_kwargs = dict ( embed_dims = ( 96 , 192 , 384 ) , num_heads = ( 3 , 6 , 12 ) , depths = ( 2 , 2 , 8 ) , * * kwargs )
model = _create_nest ( ' jx_nest_tiny ' , pretrained = pretrained , * * model_kwargs )
return model