@ -16,24 +16,19 @@ Copyright 2021 Alexander Soare
"""
"""
import collections . abc
import collections . abc
from functools import partial
import math
import logging
import logging
import math
from functools import partial
import numpy as np
import torch
import torch
from torch import nn
import torch . nn . functional as F
import torch . nn . functional as F
from torch import nn
from timm . data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD
from timm . data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD
from . helpers import build_model_with_cfg , named_apply
from . layers import PatchEmbed , Mlp , DropPath , create_classifier , trunc_normal_
from . layers import PatchEmbed , Mlp , DropPath , create_classifier , trunc_normal_
from . layers . helpers import to_ntuple
from . layers import create_conv2d , create_pool2d , to_ntuple
from . layers . create_conv2d import create_conv2d
from . layers . pool2d_same import create_pool2d
from . vision_transformer import Block
from . registry import register_model
from . registry import register_model
from . helpers import build_model_with_cfg , named_apply
from . vision_transformer import resize_pos_embed
_logger = logging . getLogger ( __name__ )
_logger = logging . getLogger ( __name__ )
@ -54,9 +49,12 @@ default_cfgs = {
' nest_base ' : _cfg ( ) ,
' nest_base ' : _cfg ( ) ,
' nest_small ' : _cfg ( ) ,
' nest_small ' : _cfg ( ) ,
' nest_tiny ' : _cfg ( ) ,
' nest_tiny ' : _cfg ( ) ,
' jx_nest_base ' : _cfg ( url = ' https://www.todo-this-is-a-placeholder.com/jx_nest_base.pth ' ) , # TODO
' jx_nest_base ' : _cfg (
' jx_nest_small ' : _cfg ( url = ' https://www.todo-this-is-a-placeholder.com/jx_nest_small.pth ' ) , # TODO
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_base-8bc41011.pth ' ) ,
' jx_nest_tiny ' : _cfg ( url = ' https://www.todo-this-is-a-placeholder.com/jx_nest_tiny.pth ' ) , # TODO
' jx_nest_small ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_small-422eaded.pth ' ) ,
' jx_nest_tiny ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_tiny-e3428fb9.pth ' ) ,
}
}
@ -96,7 +94,7 @@ class Attention(nn.Module):
return x # (B, T, N, C)
return x # (B, T, N, C)
class TransformerLayer ( Block ) :
class TransformerLayer ( nn. Module ) :
"""
"""
This is much like ` . vision_transformer . Block ` but :
This is much like ` . vision_transformer . Block ` but :
- Called TransformerLayer here to allow for " block " as defined in the paper ( " non-overlapping image blocks " )
- Called TransformerLayer here to allow for " block " as defined in the paper ( " non-overlapping image blocks " )
@ -104,8 +102,7 @@ class TransformerLayer(Block):
"""
"""
def __init__ ( self , dim , num_heads , mlp_ratio = 4. , qkv_bias = False , drop = 0. , attn_drop = 0. , drop_path = 0. ,
def __init__ ( self , dim , num_heads , mlp_ratio = 4. , qkv_bias = False , drop = 0. , attn_drop = 0. , drop_path = 0. ,
act_layer = nn . GELU , norm_layer = nn . LayerNorm ) :
act_layer = nn . GELU , norm_layer = nn . LayerNorm ) :
super ( ) . __init__ ( dim , num_heads , mlp_ratio = 4. , qkv_bias = False , drop = 0. , attn_drop = 0. , drop_path = 0. ,
super ( ) . __init__ ( )
act_layer = nn . GELU , norm_layer = nn . LayerNorm )
self . norm1 = norm_layer ( dim )
self . norm1 = norm_layer ( dim )
self . attn = Attention ( dim , num_heads = num_heads , qkv_bias = qkv_bias , attn_drop = attn_drop , proj_drop = drop )
self . attn = Attention ( dim , num_heads = num_heads , qkv_bias = qkv_bias , attn_drop = attn_drop , proj_drop = drop )
self . drop_path = DropPath ( drop_path ) if drop_path > 0. else nn . Identity ( )
self . drop_path = DropPath ( drop_path ) if drop_path > 0. else nn . Identity ( )
@ -120,7 +117,7 @@ class TransformerLayer(Block):
return x
return x
class BlockAggregation ( nn . Module ) :
class ConvPool ( nn . Module ) :
def __init__ ( self , in_channels , out_channels , norm_layer , pad_type = ' ' ) :
def __init__ ( self , in_channels , out_channels , norm_layer , pad_type = ' ' ) :
super ( ) . __init__ ( )
super ( ) . __init__ ( )
self . conv = create_conv2d ( in_channels , out_channels , kernel_size = 3 , padding = pad_type , bias = True )
self . conv = create_conv2d ( in_channels , out_channels , kernel_size = 3 , padding = pad_type , bias = True )
@ -152,8 +149,7 @@ def blockify(x, block_size: int):
grid_height = H / / block_size
grid_height = H / / block_size
grid_width = W / / block_size
grid_width = W / / block_size
x = x . reshape ( B , grid_height , block_size , grid_width , block_size , C )
x = x . reshape ( B , grid_height , block_size , grid_width , block_size , C )
x = x . permute ( 0 , 1 , 3 , 2 , 4 , 5 )
x = x . transpose ( 2 , 3 ) . reshape ( B , grid_height * grid_width , - 1 , C )
x = x . reshape ( B , grid_height * grid_width , - 1 , C )
return x # (B, T, N, C)
return x # (B, T, N, C)
@ -163,23 +159,30 @@ def deblockify(x, block_size: int):
x ( Tensor ) : with shape ( B , T , N , C ) where T is number of blocks and N is sequence size per block
x ( Tensor ) : with shape ( B , T , N , C ) where T is number of blocks and N is sequence size per block
block_size ( int ) : edge length of a single square block in units of desired H , W
block_size ( int ) : edge length of a single square block in units of desired H , W
"""
"""
B , T , _ , C = x . shape
B , T , _ , C = x . shape
grid_size = int ( math . sqrt ( T ) )
grid_size = int ( math . sqrt ( T ) )
x = x . reshape ( B , grid_size , grid_size , block_size , block_size , C )
x = x . permute ( 0 , 1 , 3 , 2 , 4 , 5 )
height = width = grid_size * block_size
height = width = grid_size * block_size
x = x . reshape ( B , height , width , C )
x = x . reshape ( B , grid_size , grid_size , block_size , block_size , C )
x = x . transpose ( 2 , 3 ) . reshape ( B , height , width , C )
return x # (B, H, W, C)
return x # (B, H, W, C)
class NestLevel ( nn . Module ) :
class NestLevel ( nn . Module ) :
""" Single hierarchical level of a Nested Transformer
""" Single hierarchical level of a Nested Transformer
"""
"""
def __init__ ( self , num_blocks , block_size , seq_length , num_heads , depth , embed_dim , mlp_ratio = 4. , qkv_bias = True ,
def __init__ (
drop_rate = 0. , attn_drop_rate = 0. , drop_path_rates = [ ] , norm_layer = None , act_layer = None ) :
self , num_blocks , block_size , seq_length , num_heads , depth , embed_dim , prev_embed_dim = None ,
mlp_ratio = 4. , qkv_bias = True , drop_rate = 0. , attn_drop_rate = 0. , drop_path_rates = [ ] ,
norm_layer = None , act_layer = None , pad_type = ' ' ) :
super ( ) . __init__ ( )
super ( ) . __init__ ( )
self . block_size = block_size
self . block_size = block_size
self . pos_embed = nn . Parameter ( torch . zeros ( 1 , num_blocks , seq_length , embed_dim ) )
self . pos_embed = nn . Parameter ( torch . zeros ( 1 , num_blocks , seq_length , embed_dim ) )
if prev_embed_dim is not None :
self . pool = ConvPool ( prev_embed_dim , embed_dim , norm_layer = norm_layer , pad_type = pad_type )
else :
self . pool = nn . Identity ( )
# Transformer encoder
# Transformer encoder
if len ( drop_path_rates ) :
if len ( drop_path_rates ) :
assert len ( drop_path_rates ) == depth , ' Must provide as many drop path rates as there are transformer layers '
assert len ( drop_path_rates ) == depth , ' Must provide as many drop path rates as there are transformer layers '
@ -194,15 +197,14 @@ class NestLevel(nn.Module):
"""
"""
expects x as ( B , C , H , W )
expects x as ( B , C , H , W )
"""
"""
# Switch to channels last for transformer
x = self . pool ( x )
x = x . permute ( 0 , 2 , 3 , 1 ) # (B, H', W', C)
x = x . permute ( 0 , 2 , 3 , 1 ) # (B, H', W', C) , switch to channels last for transformer
x = blockify ( x , self . block_size ) # (B, T, N, C')
x = blockify ( x , self . block_size ) # (B, T, N, C')
x = x + self . pos_embed
x = x + self . pos_embed
x = self . transformer_encoder ( x ) # (B, T, N, C')
x = self . transformer_encoder ( x ) # (B, T, N, C')
x = deblockify ( x , self . block_size ) # (B, H', W', C')
x = deblockify ( x , self . block_size ) # (B, H', W', C')
# Channel-first for block aggregation, and generally to replicate convnet feature map at each stage
# Channel-first for block aggregation, and generally to replicate convnet feature map at each stage
x = x . permute ( 0 , 3 , 1 , 2 ) # (B, C, H', W')
return x . permute ( 0 , 3 , 1 , 2 ) # (B, C, H', W')
return x
class Nest ( nn . Module ) :
class Nest ( nn . Module ) :
@ -213,9 +215,9 @@ class Nest(nn.Module):
"""
"""
def __init__ ( self , img_size = 224 , in_chans = 3 , patch_size = 4 , num_levels = 3 , embed_dims = ( 128 , 256 , 512 ) ,
def __init__ ( self , img_size = 224 , in_chans = 3 , patch_size = 4 , num_levels = 3 , embed_dims = ( 128 , 256 , 512 ) ,
num_heads = ( 4 , 8 , 16 ) , depths = ( 2 , 2 , 20 ) , num_classes = 1000 , mlp_ratio = 4. , qkv_bias = True , pad_type = ' ' ,
num_heads = ( 4 , 8 , 16 ) , depths = ( 2 , 2 , 20 ) , num_classes = 1000 , mlp_ratio = 4. , qkv_bias = True ,
drop_rate = 0. , attn_drop_rate = 0. , drop_path_rate = 0.5 , norm_layer = None , act_layer = None , weight_init = ' ' ,
drop_rate = 0. , attn_drop_rate = 0. , drop_path_rate = 0.5 , norm_layer = None , act_layer = None ,
global_pool= ' avg ' ) :
pad_type= ' ' , weight_init = ' ' , global_pool= ' avg ' ) :
"""
"""
Args :
Args :
img_size ( int , tuple ) : input image size
img_size ( int , tuple ) : input image size
@ -233,6 +235,7 @@ class Nest(nn.Module):
drop_path_rate ( float ) : stochastic depth rate
drop_path_rate ( float ) : stochastic depth rate
norm_layer : ( nn . Module ) : normalization layer for transformer layers
norm_layer : ( nn . Module ) : normalization layer for transformer layers
act_layer : ( nn . Module ) : activation layer in MLP of transformer layers
act_layer : ( nn . Module ) : activation layer in MLP of transformer layers
pad_type : str : Type of padding to use ' ' for PyTorch symmetric , ' same ' for TF SAME
weight_init : ( str ) : weight init scheme
weight_init : ( str ) : weight init scheme
global_pool : ( str ) : type of pooling operation to apply to final feature map
global_pool : ( str ) : type of pooling operation to apply to final feature map
@ -254,6 +257,7 @@ class Nest(nn.Module):
depths = to_ntuple ( num_levels ) ( depths )
depths = to_ntuple ( num_levels ) ( depths )
self . num_classes = num_classes
self . num_classes = num_classes
self . num_features = embed_dims [ - 1 ]
self . num_features = embed_dims [ - 1 ]
self . feature_info = [ ]
norm_layer = norm_layer or partial ( nn . LayerNorm , eps = 1e-6 )
norm_layer = norm_layer or partial ( nn . LayerNorm , eps = 1e-6 )
act_layer = act_layer or nn . GELU
act_layer = act_layer or nn . GELU
self . drop_rate = drop_rate
self . drop_rate = drop_rate
@ -265,60 +269,54 @@ class Nest(nn.Module):
self . patch_size = patch_size
self . patch_size = patch_size
# Number of blocks at each level
# Number of blocks at each level
self . num_blocks = 4 * * ( np . arange ( num_levels ) [ : : - 1 ] )
self . num_blocks = ( 4 * * torch . arange ( num_levels ) ) . flip ( 0 ) . tolist ( )
assert ( img_size / / patch_size ) % np . sqrt ( self . num_blocks [ 0 ] ) == 0 , \
assert ( img_size / / patch_size ) % math . sqrt ( self . num_blocks [ 0 ] ) == 0 , \
' First level blocks don \' t fit evenly. Check `img_size`, `patch_size`, and `num_levels` '
' First level blocks don \' t fit evenly. Check `img_size`, `patch_size`, and `num_levels` '
# Block edge size in units of patches
# Block edge size in units of patches
# Hint: (img_size // patch_size) gives number of patches along edge of image. sqrt(self.num_blocks[0]) is the
# Hint: (img_size // patch_size) gives number of patches along edge of image. sqrt(self.num_blocks[0]) is the
# number of blocks along edge of image
# number of blocks along edge of image
self . block_size = int ( ( img_size / / patch_size ) / / np . sqrt ( self . num_blocks [ 0 ] ) )
self . block_size = int ( ( img_size / / patch_size ) / / math . sqrt ( self . num_blocks [ 0 ] ) )
# Patch embedding
# Patch embedding
self . patch_embed = PatchEmbed (
self . patch_embed = PatchEmbed (
img_size = img_size , patch_size = patch_size , in_chans = in_chans , embed_dim = embed_dims [ 0 ] )
img_size = img_size , patch_size = patch_size , in_chans = in_chans , embed_dim = embed_dims [ 0 ] , flatten = False )
self . num_patches = self . patch_embed . num_patches
self . num_patches = self . patch_embed . num_patches
self . seq_length = self . num_patches / / self . num_blocks [ 0 ]
self . seq_length = self . num_patches / / self . num_blocks [ 0 ]
# Build up each hierarchical level
# Build up each hierarchical level
self . levels = nn . ModuleList ( [ ] )
levels = [ ]
self . block_aggs = nn . ModuleList ( [ ] )
dp_rates = [ x . tolist ( ) for x in torch . linspace ( 0 , drop_path_rate , sum ( depths ) ) . split ( depths ) ]
drop_path_rates = [ x . item ( ) for x in torch . linspace ( 0 , drop_path_rate , sum ( depths ) ) ]
prev_dim = None
for lix in range ( self . num_levels ) :
curr_stride = 4
dpr = drop_path_rates [ sum ( depths [ : lix ] ) : sum ( depths [ : lix + 1 ] ) ]
for i in range ( len ( self . num_blocks ) ) :
self . levels . append ( NestLevel (
dim = embed_dims [ i ]
self . num_blocks [ lix ] , self . block_size , self . seq_length , num_heads [ lix ] , depths [ lix ] ,
levels . append ( NestLevel (
embed_dims [ lix ] , mlp_ratio , qkv_bias , drop_rate , attn_drop_rate , dpr , norm_layer ,
self . num_blocks [ i ] , self . block_size , self . seq_length , num_heads [ i ] , depths [ i ] , dim , prev_dim ,
act_layer ) )
mlp_ratio , qkv_bias , drop_rate , attn_drop_rate , dp_rates [ i ] , norm_layer , act_layer , pad_type = pad_type ) )
if lix < self . num_levels - 1 :
self . feature_info + = [ dict ( num_chs = dim , reduction = curr_stride , module = f ' levels. { i } ' ) ]
self . block_aggs . append ( BlockAggregation (
prev_dim = dim
embed_dims [ lix ] , embed_dims [ lix + 1 ] , norm_layer , pad_type = pad_type ) )
curr_stride * = 2
else :
self . levels = nn . Sequential ( * levels )
# Required for zipped iteration over levels and ls_block_agg together
self . block_aggs . append ( nn . Identity ( ) )
# Final normalization layer
# Final normalization layer
self . norm = norm_layer ( embed_dims [ - 1 ] )
self . norm = norm_layer ( embed_dims [ - 1 ] )
# Classifier
# Classifier
self . global_pool , self . head = create_classifier (
self . global_pool , self . head = create_classifier ( self . num_features , self . num_classes , pool_type = global_pool )
self . num_features , self . num_classes , pool_type = global_pool )
self . init_weights ( weight_init )
self . init_weights ( weight_init )
def init_weights ( self , mode = ' ' ) :
def init_weights ( self , mode = ' ' ) :
assert mode in ( ' jax' , ' jax_nlhb ' , ' nlhb' , ' ' )
assert mode in ( ' nlhb' , ' ' )
head_bias = - math . log ( self . num_classes ) if ' nlhb ' in mode else 0.
head_bias = - math . log ( self . num_classes ) if ' nlhb ' in mode else 0.
for level in self . levels :
for level in self . levels :
trunc_normal_ ( level . pos_embed , std = .02 , a = - 2 , b = 2 )
trunc_normal_ ( level . pos_embed , std = .02 , a = - 2 , b = 2 )
if mode . startswith ( ' jax ' ) :
named_apply ( partial ( _init_nest_weights , head_bias = head_bias ) , self )
named_apply ( partial ( _init_nest_weights , head_bias = head_bias , jax_impl = True ) , self )
else :
named_apply ( _init_nest_weights , self )
@torch.jit.ignore
@torch.jit.ignore
def no_weight_decay ( self ) :
def no_weight_decay ( self ) :
return { ' pos_embed' }
return { f ' level.{ i } . pos_embed' for i in range ( len ( self . levels ) ) }
def get_classifier ( self ) :
def get_classifier ( self ) :
return self . head
return self . head
@ -333,13 +331,8 @@ class Nest(nn.Module):
"""
"""
B , _ , H , W = x . shape
B , _ , H , W = x . shape
x = self . patch_embed ( x )
x = self . patch_embed ( x )
x = x . reshape ( B , H / / self . patch_size , W / / self . patch_size , - 1 ) # (B, H', W', C')
x = self . levels ( x )
x = x . permute ( 0 , 3 , 1 , 2 )
# Layer norm done over channel dim only (to NHWC and back)
# NOTE: TorchScript won't let us subscript module lists with integer variables, so we iterate instead
for level , block_agg in zip ( self . levels , self . block_aggs ) :
x = level ( x )
x = block_agg ( x )
# Layer norm done over channel dim only
x = self . norm ( x . permute ( 0 , 2 , 3 , 1 ) ) . permute ( 0 , 3 , 1 , 2 )
x = self . norm ( x . permute ( 0 , 2 , 3 , 1 ) ) . permute ( 0 , 3 , 1 , 2 )
return x
return x
@ -353,22 +346,19 @@ class Nest(nn.Module):
return self . head ( x )
return self . head ( x )
def _init_nest_weights ( module : nn . Module , name : str = ' ' , head_bias : float = 0. , jax_impl : bool = False ):
def _init_nest_weights ( module : nn . Module , name : str = ' ' , head_bias : float = 0. ):
""" NesT weight initialization
""" NesT weight initialization
Can replicate Jax implementation . Otherwise follows vision_transformer . py
Can replicate Jax implementation . Otherwise follows vision_transformer . py
"""
"""
if isinstance ( module , nn . Linear ) :
if isinstance ( module , nn . Linear ) :
if name . startswith ( ' head ' ) :
if name . startswith ( ' head ' ) :
if jax_impl :
trunc_normal_ ( module . weight , std = .02 , a = - 2 , b = 2 )
trunc_normal_ ( module . weight , std = .02 , a = - 2 , b = 2 )
else :
nn . init . zeros_ ( module . weight )
nn . init . constant_ ( module . bias , head_bias )
nn . init . constant_ ( module . bias , head_bias )
else :
else :
trunc_normal_ ( module . weight , std = .02 , a = - 2 , b = 2 )
trunc_normal_ ( module . weight , std = .02 , a = - 2 , b = 2 )
if module . bias is not None :
if module . bias is not None :
nn . init . zeros_ ( module . bias )
nn . init . zeros_ ( module . bias )
elif jax_impl and isinstance ( module , nn . Conv2d ) :
elif isinstance ( module , nn . Conv2d ) :
trunc_normal_ ( module . weight , std = .02 , a = - 2 , b = 2 )
trunc_normal_ ( module . weight , std = .02 , a = - 2 , b = 2 )
if module . bias is not None :
if module . bias is not None :
nn . init . zeros_ ( module . bias )
nn . init . zeros_ ( module . bias )
@ -404,13 +394,11 @@ def checkpoint_filter_fn(state_dict, model):
def _create_nest ( variant , pretrained = False , default_cfg = None , * * kwargs ) :
def _create_nest ( variant , pretrained = False , default_cfg = None , * * kwargs ) :
if kwargs . get ( ' features_only ' , None ) :
raise RuntimeError ( ' features_only not implemented for Vision Transformer models. ' )
default_cfg = default_cfg or default_cfgs [ variant ]
default_cfg = default_cfg or default_cfgs [ variant ]
model = build_model_with_cfg (
model = build_model_with_cfg (
Nest , variant , pretrained ,
Nest , variant , pretrained ,
default_cfg = default_cfg ,
default_cfg = default_cfg ,
feature_cfg = dict ( out_indices = ( 0 , 1 , 2 ) , flatten_sequential = True ) ,
pretrained_filter_fn = checkpoint_filter_fn ,
pretrained_filter_fn = checkpoint_filter_fn ,
* * kwargs )
* * kwargs )
@ -422,7 +410,7 @@ def nest_base(pretrained=False, **kwargs):
""" Nest-B @ 224x224
""" Nest-B @ 224x224
"""
"""
model_kwargs = dict (
model_kwargs = dict (
embed_dims = ( 128 , 256 , 512 ) , num_heads = ( 4 , 8 , 16 ) , depths = ( 2 , 2 , 20 ) , drop_path_rate = 0.5 , * * kwargs )
embed_dims = ( 128 , 256 , 512 ) , num_heads = ( 4 , 8 , 16 ) , depths = ( 2 , 2 , 20 ) , * * kwargs )
model = _create_nest ( ' nest_base ' , pretrained = pretrained , * * model_kwargs )
model = _create_nest ( ' nest_base ' , pretrained = pretrained , * * model_kwargs )
return model
return model
@ -431,8 +419,7 @@ def nest_base(pretrained=False, **kwargs):
def nest_small ( pretrained = False , * * kwargs ) :
def nest_small ( pretrained = False , * * kwargs ) :
""" Nest-S @ 224x224
""" Nest-S @ 224x224
"""
"""
model_kwargs = dict (
model_kwargs = dict ( embed_dims = ( 96 , 192 , 384 ) , num_heads = ( 3 , 6 , 12 ) , depths = ( 2 , 2 , 20 ) , * * kwargs )
embed_dims = ( 96 , 192 , 384 ) , num_heads = ( 3 , 6 , 12 ) , depths = ( 2 , 2 , 20 ) , drop_path_rate = 0.3 , * * kwargs )
model = _create_nest ( ' nest_small ' , pretrained = pretrained , * * model_kwargs )
model = _create_nest ( ' nest_small ' , pretrained = pretrained , * * model_kwargs )
return model
return model
@ -441,8 +428,7 @@ def nest_small(pretrained=False, **kwargs):
def nest_tiny ( pretrained = False , * * kwargs ) :
def nest_tiny ( pretrained = False , * * kwargs ) :
""" Nest-T @ 224x224
""" Nest-T @ 224x224
"""
"""
model_kwargs = dict (
model_kwargs = dict ( embed_dims = ( 96 , 192 , 384 ) , num_heads = ( 3 , 6 , 12 ) , depths = ( 2 , 2 , 8 ) , * * kwargs )
embed_dims = ( 96 , 192 , 384 ) , num_heads = ( 3 , 6 , 12 ) , depths = ( 2 , 2 , 8 ) , drop_path_rate = 0.2 , * * kwargs )
model = _create_nest ( ' nest_tiny ' , pretrained = pretrained , * * model_kwargs )
model = _create_nest ( ' nest_tiny ' , pretrained = pretrained , * * model_kwargs )
return model
return model
@ -452,9 +438,7 @@ def jx_nest_base(pretrained=False, **kwargs):
""" Nest-B @ 224x224, Pretrained weights converted from official Jax impl.
""" Nest-B @ 224x224, Pretrained weights converted from official Jax impl.
"""
"""
kwargs [ ' pad_type ' ] = ' same '
kwargs [ ' pad_type ' ] = ' same '
kwargs [ ' weight_init ' ] = ' jax '
model_kwargs = dict ( embed_dims = ( 128 , 256 , 512 ) , num_heads = ( 4 , 8 , 16 ) , depths = ( 2 , 2 , 20 ) , * * kwargs )
model_kwargs = dict (
embed_dims = ( 128 , 256 , 512 ) , num_heads = ( 4 , 8 , 16 ) , depths = ( 2 , 2 , 20 ) , drop_path_rate = 0.5 , * * kwargs )
model = _create_nest ( ' jx_nest_base ' , pretrained = pretrained , * * model_kwargs )
model = _create_nest ( ' jx_nest_base ' , pretrained = pretrained , * * model_kwargs )
return model
return model
@ -464,9 +448,7 @@ def jx_nest_small(pretrained=False, **kwargs):
""" Nest-S @ 224x224, Pretrained weights converted from official Jax impl.
""" Nest-S @ 224x224, Pretrained weights converted from official Jax impl.
"""
"""
kwargs [ ' pad_type ' ] = ' same '
kwargs [ ' pad_type ' ] = ' same '
kwargs [ ' weight_init ' ] = ' jax '
model_kwargs = dict ( embed_dims = ( 96 , 192 , 384 ) , num_heads = ( 3 , 6 , 12 ) , depths = ( 2 , 2 , 20 ) , * * kwargs )
model_kwargs = dict (
embed_dims = ( 96 , 192 , 384 ) , num_heads = ( 3 , 6 , 12 ) , depths = ( 2 , 2 , 20 ) , drop_path_rate = 0.3 , * * kwargs )
model = _create_nest ( ' jx_nest_small ' , pretrained = pretrained , * * model_kwargs )
model = _create_nest ( ' jx_nest_small ' , pretrained = pretrained , * * model_kwargs )
return model
return model
@ -476,8 +458,6 @@ def jx_nest_tiny(pretrained=False, **kwargs):
""" Nest-T @ 224x224, Pretrained weights converted from official Jax impl.
""" Nest-T @ 224x224, Pretrained weights converted from official Jax impl.
"""
"""
kwargs [ ' pad_type ' ] = ' same '
kwargs [ ' pad_type ' ] = ' same '
kwargs [ ' weight_init ' ] = ' jax '
model_kwargs = dict ( embed_dims = ( 96 , 192 , 384 ) , num_heads = ( 3 , 6 , 12 ) , depths = ( 2 , 2 , 8 ) , * * kwargs )
model_kwargs = dict (
embed_dims = ( 96 , 192 , 384 ) , num_heads = ( 3 , 6 , 12 ) , depths = ( 2 , 2 , 8 ) , drop_path_rate = 0.2 , * * kwargs )
model = _create_nest ( ' jx_nest_tiny ' , pretrained = pretrained , * * model_kwargs )
model = _create_nest ( ' jx_nest_tiny ' , pretrained = pretrained , * * model_kwargs )
return model
return model