Move aggregation (convpool) for nest into NestLevel, cleanup and enable features_only use. Finalize weight url.

pull/738/head
Ross Wightman 4 years ago
parent 6ae0ac6420
commit 81cd6863c8

@ -23,6 +23,9 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
## What's New ## What's New
### July 5, 2021
* Add 'Aggregating Nested Transformer' (NesT) w/ weights converted from official [Flax impl](https://github.com/google-research/nested-transformer). Contributed by [Alexander Soare](https://github.com/alexander-soare).
### June 23, 2021 ### June 23, 2021
* Reproduce gMLP model training, `gmlp_s16_224` trained to 79.6 top-1, matching [paper](https://arxiv.org/abs/2105.08050). Hparams for this and other recent MLP training [here](https://gist.github.com/rwightman/d6c264a9001f9167e06c209f630b2cc6) * Reproduce gMLP model training, `gmlp_s16_224` trained to 79.6 top-1, matching [paper](https://arxiv.org/abs/2105.08050). Hparams for this and other recent MLP training [here](https://gist.github.com/rwightman/d6c264a9001f9167e06c209f630b2cc6)

@ -79,18 +79,18 @@ def convert_nest(checkpoint_path, arch):
state_dict[f'levels.{level}.transformer_encoder.{layer}.mlp.fc{i+1}.bias'] = torch.tensor( state_dict[f'levels.{level}.transformer_encoder.{layer}.mlp.fc{i+1}.bias'] = torch.tensor(
flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MlpBlock_0'][f'Dense_{i}']['bias']) flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MlpBlock_0'][f'Dense_{i}']['bias'])
# Block aggregations # Block aggregations (ConvPool)
for level in range(len(depths)-1): for level in range(1, len(depths)):
# Convs # Convs
state_dict[f'block_aggs.{level}.conv.weight'] = torch.tensor( state_dict[f'levels.{level}.pool.conv.weight'] = torch.tensor(
flax_dict[f'ConvPool_{level}']['Conv_0']['kernel']).permute(3, 2, 0, 1) flax_dict[f'ConvPool_{level-1}']['Conv_0']['kernel']).permute(3, 2, 0, 1)
state_dict[f'block_aggs.{level}.conv.bias'] = torch.tensor( state_dict[f'levels.{level}.pool.conv.bias'] = torch.tensor(
flax_dict[f'ConvPool_{level}']['Conv_0']['bias']) flax_dict[f'ConvPool_{level-1}']['Conv_0']['bias'])
# Norms # Norms
state_dict[f'block_aggs.{level}.norm.weight'] = torch.tensor( state_dict[f'levels.{level}.pool.norm.weight'] = torch.tensor(
flax_dict[f'ConvPool_{level}']['LayerNorm_0']['scale']) flax_dict[f'ConvPool_{level-1}']['LayerNorm_0']['scale'])
state_dict[f'block_aggs.{level}.norm.bias'] = torch.tensor( state_dict[f'levels.{level}.pool.norm.bias'] = torch.tensor(
flax_dict[f'ConvPool_{level}']['LayerNorm_0']['bias']) flax_dict[f'ConvPool_{level-1}']['LayerNorm_0']['bias'])
# Final norm # Final norm
state_dict[f'norm.weight'] = torch.tensor(flax_dict['LayerNorm_0']['scale']) state_dict[f'norm.weight'] = torch.tensor(flax_dict['LayerNorm_0']['scale'])
@ -105,5 +105,5 @@ def convert_nest(checkpoint_path, arch):
if __name__ == '__main__': if __name__ == '__main__':
variant = sys.argv[1] # base, small, or tiny variant = sys.argv[1] # base, small, or tiny
state_dict = convert_nest(f'../nested-transformer/checkpoints/nest-{variant[0]}_imagenet', f'nest_{variant}') state_dict = convert_nest(f'./nest-{variant[0]}_imagenet', f'nest_{variant}')
torch.save(state_dict, f'/home/alexander/.cache/torch/hub/checkpoints/jx_nest_{variant}.pth') torch.save(state_dict, f'./jx_nest_{variant}.pth')

@ -16,24 +16,19 @@ Copyright 2021 Alexander Soare
""" """
import collections.abc import collections.abc
from functools import partial
import math
import logging import logging
import math
from functools import partial
import numpy as np
import torch import torch
from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, named_apply
from .layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_ from .layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_
from .layers.helpers import to_ntuple from .layers import create_conv2d, create_pool2d, to_ntuple
from .layers.create_conv2d import create_conv2d
from .layers.pool2d_same import create_pool2d
from .vision_transformer import Block
from .registry import register_model from .registry import register_model
from .helpers import build_model_with_cfg, named_apply
from .vision_transformer import resize_pos_embed
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
@ -54,9 +49,12 @@ default_cfgs = {
'nest_base': _cfg(), 'nest_base': _cfg(),
'nest_small': _cfg(), 'nest_small': _cfg(),
'nest_tiny': _cfg(), 'nest_tiny': _cfg(),
'jx_nest_base': _cfg(url='https://www.todo-this-is-a-placeholder.com/jx_nest_base.pth'), # TODO 'jx_nest_base': _cfg(
'jx_nest_small': _cfg(url='https://www.todo-this-is-a-placeholder.com/jx_nest_small.pth'), # TODO url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_base-8bc41011.pth'),
'jx_nest_tiny': _cfg(url='https://www.todo-this-is-a-placeholder.com/jx_nest_tiny.pth'), # TODO 'jx_nest_small': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_small-422eaded.pth'),
'jx_nest_tiny': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_tiny-e3428fb9.pth'),
} }
@ -96,7 +94,7 @@ class Attention(nn.Module):
return x # (B, T, N, C) return x # (B, T, N, C)
class TransformerLayer(Block): class TransformerLayer(nn.Module):
""" """
This is much like `.vision_transformer.Block` but: This is much like `.vision_transformer.Block` but:
- Called TransformerLayer here to allow for "block" as defined in the paper ("non-overlapping image blocks") - Called TransformerLayer here to allow for "block" as defined in the paper ("non-overlapping image blocks")
@ -104,8 +102,7 @@ class TransformerLayer(Block):
""" """
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0., def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm): act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__(dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0., super().__init__()
act_layer=nn.GELU, norm_layer=nn.LayerNorm)
self.norm1 = norm_layer(dim) self.norm1 = norm_layer(dim)
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -120,7 +117,7 @@ class TransformerLayer(Block):
return x return x
class BlockAggregation(nn.Module): class ConvPool(nn.Module):
def __init__(self, in_channels, out_channels, norm_layer, pad_type=''): def __init__(self, in_channels, out_channels, norm_layer, pad_type=''):
super().__init__() super().__init__()
self.conv = create_conv2d(in_channels, out_channels, kernel_size=3, padding=pad_type, bias=True) self.conv = create_conv2d(in_channels, out_channels, kernel_size=3, padding=pad_type, bias=True)
@ -152,8 +149,7 @@ def blockify(x, block_size: int):
grid_height = H // block_size grid_height = H // block_size
grid_width = W // block_size grid_width = W // block_size
x = x.reshape(B, grid_height, block_size, grid_width, block_size, C) x = x.reshape(B, grid_height, block_size, grid_width, block_size, C)
x = x.permute(0, 1, 3, 2, 4, 5) x = x.transpose(2, 3).reshape(B, grid_height * grid_width, -1, C)
x = x.reshape(B, grid_height * grid_width, -1, C)
return x # (B, T, N, C) return x # (B, T, N, C)
@ -163,23 +159,30 @@ def deblockify(x, block_size: int):
x (Tensor): with shape (B, T, N, C) where T is number of blocks and N is sequence size per block x (Tensor): with shape (B, T, N, C) where T is number of blocks and N is sequence size per block
block_size (int): edge length of a single square block in units of desired H, W block_size (int): edge length of a single square block in units of desired H, W
""" """
B, T, _, C= x.shape B, T, _, C = x.shape
grid_size = int(math.sqrt(T)) grid_size = int(math.sqrt(T))
x = x.reshape(B, grid_size, grid_size, block_size, block_size, C)
x = x.permute(0, 1, 3, 2, 4, 5)
height = width = grid_size * block_size height = width = grid_size * block_size
x = x.reshape(B, height, width, C) x = x.reshape(B, grid_size, grid_size, block_size, block_size, C)
x = x.transpose(2, 3).reshape(B, height, width, C)
return x # (B, H, W, C) return x # (B, H, W, C)
class NestLevel(nn.Module): class NestLevel(nn.Module):
""" Single hierarchical level of a Nested Transformer """ Single hierarchical level of a Nested Transformer
""" """
def __init__(self, num_blocks, block_size, seq_length, num_heads, depth, embed_dim, mlp_ratio=4., qkv_bias=True, def __init__(
drop_rate=0., attn_drop_rate=0., drop_path_rates=[], norm_layer=None, act_layer=None): self, num_blocks, block_size, seq_length, num_heads, depth, embed_dim, prev_embed_dim=None,
mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rates=[],
norm_layer=None, act_layer=None, pad_type=''):
super().__init__() super().__init__()
self.block_size = block_size self.block_size = block_size
self.pos_embed = nn.Parameter(torch.zeros(1, num_blocks, seq_length, embed_dim)) self.pos_embed = nn.Parameter(torch.zeros(1, num_blocks, seq_length, embed_dim))
if prev_embed_dim is not None:
self.pool = ConvPool(prev_embed_dim, embed_dim, norm_layer=norm_layer, pad_type=pad_type)
else:
self.pool = nn.Identity()
# Transformer encoder # Transformer encoder
if len(drop_path_rates): if len(drop_path_rates):
assert len(drop_path_rates) == depth, 'Must provide as many drop path rates as there are transformer layers' assert len(drop_path_rates) == depth, 'Must provide as many drop path rates as there are transformer layers'
@ -194,15 +197,14 @@ class NestLevel(nn.Module):
""" """
expects x as (B, C, H, W) expects x as (B, C, H, W)
""" """
# Switch to channels last for transformer x = self.pool(x)
x = x.permute(0, 2, 3, 1) # (B, H', W', C) x = x.permute(0, 2, 3, 1) # (B, H', W', C), switch to channels last for transformer
x = blockify(x, self.block_size) # (B, T, N, C') x = blockify(x, self.block_size) # (B, T, N, C')
x = x + self.pos_embed x = x + self.pos_embed
x = self.transformer_encoder(x) # (B, T, N, C') x = self.transformer_encoder(x) # (B, T, N, C')
x = deblockify(x, self.block_size) # (B, H', W', C') x = deblockify(x, self.block_size) # (B, H', W', C')
# Channel-first for block aggregation, and generally to replicate convnet feature map at each stage # Channel-first for block aggregation, and generally to replicate convnet feature map at each stage
x = x.permute(0, 3, 1, 2) # (B, C, H', W') return x.permute(0, 3, 1, 2) # (B, C, H', W')
return x
class Nest(nn.Module): class Nest(nn.Module):
@ -213,9 +215,9 @@ class Nest(nn.Module):
""" """
def __init__(self, img_size=224, in_chans=3, patch_size=4, num_levels=3, embed_dims=(128, 256, 512), def __init__(self, img_size=224, in_chans=3, patch_size=4, num_levels=3, embed_dims=(128, 256, 512),
num_heads=(4, 8, 16), depths=(2, 2, 20), num_classes=1000, mlp_ratio=4., qkv_bias=True, pad_type='', num_heads=(4, 8, 16), depths=(2, 2, 20), num_classes=1000, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.5, norm_layer=None, act_layer=None, weight_init='', drop_rate=0., attn_drop_rate=0., drop_path_rate=0.5, norm_layer=None, act_layer=None,
global_pool='avg'): pad_type='', weight_init='', global_pool='avg'):
""" """
Args: Args:
img_size (int, tuple): input image size img_size (int, tuple): input image size
@ -233,6 +235,7 @@ class Nest(nn.Module):
drop_path_rate (float): stochastic depth rate drop_path_rate (float): stochastic depth rate
norm_layer: (nn.Module): normalization layer for transformer layers norm_layer: (nn.Module): normalization layer for transformer layers
act_layer: (nn.Module): activation layer in MLP of transformer layers act_layer: (nn.Module): activation layer in MLP of transformer layers
pad_type: str: Type of padding to use '' for PyTorch symmetric, 'same' for TF SAME
weight_init: (str): weight init scheme weight_init: (str): weight init scheme
global_pool: (str): type of pooling operation to apply to final feature map global_pool: (str): type of pooling operation to apply to final feature map
@ -254,6 +257,7 @@ class Nest(nn.Module):
depths = to_ntuple(num_levels)(depths) depths = to_ntuple(num_levels)(depths)
self.num_classes = num_classes self.num_classes = num_classes
self.num_features = embed_dims[-1] self.num_features = embed_dims[-1]
self.feature_info = []
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU act_layer = act_layer or nn.GELU
self.drop_rate = drop_rate self.drop_rate = drop_rate
@ -265,60 +269,54 @@ class Nest(nn.Module):
self.patch_size = patch_size self.patch_size = patch_size
# Number of blocks at each level # Number of blocks at each level
self.num_blocks = 4**(np.arange(num_levels)[::-1]) self.num_blocks = (4 ** torch.arange(num_levels)).flip(0).tolist()
assert (img_size // patch_size) % np.sqrt(self.num_blocks[0]) == 0, \ assert (img_size // patch_size) % math.sqrt(self.num_blocks[0]) == 0, \
'First level blocks don\'t fit evenly. Check `img_size`, `patch_size`, and `num_levels`' 'First level blocks don\'t fit evenly. Check `img_size`, `patch_size`, and `num_levels`'
# Block edge size in units of patches # Block edge size in units of patches
# Hint: (img_size // patch_size) gives number of patches along edge of image. sqrt(self.num_blocks[0]) is the # Hint: (img_size // patch_size) gives number of patches along edge of image. sqrt(self.num_blocks[0]) is the
# number of blocks along edge of image # number of blocks along edge of image
self.block_size = int((img_size // patch_size) // np.sqrt(self.num_blocks[0])) self.block_size = int((img_size // patch_size) // math.sqrt(self.num_blocks[0]))
# Patch embedding # Patch embedding
self.patch_embed = PatchEmbed( self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0]) img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0], flatten=False)
self.num_patches = self.patch_embed.num_patches self.num_patches = self.patch_embed.num_patches
self.seq_length = self.num_patches // self.num_blocks[0] self.seq_length = self.num_patches // self.num_blocks[0]
# Build up each hierarchical level # Build up each hierarchical level
self.levels = nn.ModuleList([]) levels = []
self.block_aggs = nn.ModuleList([]) dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
drop_path_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] prev_dim = None
for lix in range(self.num_levels): curr_stride = 4
dpr = drop_path_rates[sum(depths[:lix]):sum(depths[:lix+1])] for i in range(len(self.num_blocks)):
self.levels.append(NestLevel( dim = embed_dims[i]
self.num_blocks[lix], self.block_size, self.seq_length, num_heads[lix], depths[lix], levels.append(NestLevel(
embed_dims[lix], mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, dpr, norm_layer, self.num_blocks[i], self.block_size, self.seq_length, num_heads[i], depths[i], dim, prev_dim,
act_layer)) mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, dp_rates[i], norm_layer, act_layer, pad_type=pad_type))
if lix < self.num_levels - 1: self.feature_info += [dict(num_chs=dim, reduction=curr_stride, module=f'levels.{i}')]
self.block_aggs.append(BlockAggregation( prev_dim = dim
embed_dims[lix], embed_dims[lix+1], norm_layer, pad_type=pad_type)) curr_stride *= 2
else: self.levels = nn.Sequential(*levels)
# Required for zipped iteration over levels and ls_block_agg together
self.block_aggs.append(nn.Identity())
# Final normalization layer # Final normalization layer
self.norm = norm_layer(embed_dims[-1]) self.norm = norm_layer(embed_dims[-1])
# Classifier # Classifier
self.global_pool, self.head = create_classifier( self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
self.num_features, self.num_classes, pool_type=global_pool)
self.init_weights(weight_init) self.init_weights(weight_init)
def init_weights(self, mode=''): def init_weights(self, mode=''):
assert mode in ('jax', 'jax_nlhb', 'nlhb', '') assert mode in ('nlhb', '')
head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
for level in self.levels: for level in self.levels:
trunc_normal_(level.pos_embed, std=.02, a=-2, b=2) trunc_normal_(level.pos_embed, std=.02, a=-2, b=2)
if mode.startswith('jax'): named_apply(partial(_init_nest_weights, head_bias=head_bias), self)
named_apply(partial(_init_nest_weights, head_bias=head_bias, jax_impl=True), self)
else:
named_apply(_init_nest_weights, self)
@torch.jit.ignore @torch.jit.ignore
def no_weight_decay(self): def no_weight_decay(self):
return {'pos_embed'} return {f'level.{i}.pos_embed' for i in range(len(self.levels))}
def get_classifier(self): def get_classifier(self):
return self.head return self.head
@ -333,13 +331,8 @@ class Nest(nn.Module):
""" """
B, _, H, W = x.shape B, _, H, W = x.shape
x = self.patch_embed(x) x = self.patch_embed(x)
x = x.reshape(B, H//self.patch_size, W//self.patch_size, -1) # (B, H', W', C') x = self.levels(x)
x = x.permute(0, 3, 1, 2) # Layer norm done over channel dim only (to NHWC and back)
# NOTE: TorchScript won't let us subscript module lists with integer variables, so we iterate instead
for level, block_agg in zip(self.levels, self.block_aggs):
x = level(x)
x = block_agg(x)
# Layer norm done over channel dim only
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x return x
@ -353,22 +346,19 @@ class Nest(nn.Module):
return self.head(x) return self.head(x)
def _init_nest_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False): def _init_nest_weights(module: nn.Module, name: str = '', head_bias: float = 0.):
""" NesT weight initialization """ NesT weight initialization
Can replicate Jax implementation. Otherwise follows vision_transformer.py Can replicate Jax implementation. Otherwise follows vision_transformer.py
""" """
if isinstance(module, nn.Linear): if isinstance(module, nn.Linear):
if name.startswith('head'): if name.startswith('head'):
if jax_impl:
trunc_normal_(module.weight, std=.02, a=-2, b=2) trunc_normal_(module.weight, std=.02, a=-2, b=2)
else:
nn.init.zeros_(module.weight)
nn.init.constant_(module.bias, head_bias) nn.init.constant_(module.bias, head_bias)
else: else:
trunc_normal_(module.weight, std=.02, a=-2, b=2) trunc_normal_(module.weight, std=.02, a=-2, b=2)
if module.bias is not None: if module.bias is not None:
nn.init.zeros_(module.bias) nn.init.zeros_(module.bias)
elif jax_impl and isinstance(module, nn.Conv2d): elif isinstance(module, nn.Conv2d):
trunc_normal_(module.weight, std=.02, a=-2, b=2) trunc_normal_(module.weight, std=.02, a=-2, b=2)
if module.bias is not None: if module.bias is not None:
nn.init.zeros_(module.bias) nn.init.zeros_(module.bias)
@ -404,13 +394,11 @@ def checkpoint_filter_fn(state_dict, model):
def _create_nest(variant, pretrained=False, default_cfg=None, **kwargs): def _create_nest(variant, pretrained=False, default_cfg=None, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
default_cfg = default_cfg or default_cfgs[variant] default_cfg = default_cfg or default_cfgs[variant]
model = build_model_with_cfg( model = build_model_with_cfg(
Nest, variant, pretrained, Nest, variant, pretrained,
default_cfg=default_cfg, default_cfg=default_cfg,
feature_cfg=dict(out_indices=(0, 1, 2), flatten_sequential=True),
pretrained_filter_fn=checkpoint_filter_fn, pretrained_filter_fn=checkpoint_filter_fn,
**kwargs) **kwargs)
@ -422,7 +410,7 @@ def nest_base(pretrained=False, **kwargs):
""" Nest-B @ 224x224 """ Nest-B @ 224x224
""" """
model_kwargs = dict( model_kwargs = dict(
embed_dims=(128, 256, 512), num_heads=(4, 8, 16), depths=(2, 2, 20), drop_path_rate=0.5, **kwargs) embed_dims=(128, 256, 512), num_heads=(4, 8, 16), depths=(2, 2, 20), **kwargs)
model = _create_nest('nest_base', pretrained=pretrained, **model_kwargs) model = _create_nest('nest_base', pretrained=pretrained, **model_kwargs)
return model return model
@ -431,8 +419,7 @@ def nest_base(pretrained=False, **kwargs):
def nest_small(pretrained=False, **kwargs): def nest_small(pretrained=False, **kwargs):
""" Nest-S @ 224x224 """ Nest-S @ 224x224
""" """
model_kwargs = dict( model_kwargs = dict(embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 20), **kwargs)
embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 20), drop_path_rate=0.3, **kwargs)
model = _create_nest('nest_small', pretrained=pretrained, **model_kwargs) model = _create_nest('nest_small', pretrained=pretrained, **model_kwargs)
return model return model
@ -441,8 +428,7 @@ def nest_small(pretrained=False, **kwargs):
def nest_tiny(pretrained=False, **kwargs): def nest_tiny(pretrained=False, **kwargs):
""" Nest-T @ 224x224 """ Nest-T @ 224x224
""" """
model_kwargs = dict( model_kwargs = dict(embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 8), **kwargs)
embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 8), drop_path_rate=0.2, **kwargs)
model = _create_nest('nest_tiny', pretrained=pretrained, **model_kwargs) model = _create_nest('nest_tiny', pretrained=pretrained, **model_kwargs)
return model return model
@ -452,9 +438,7 @@ def jx_nest_base(pretrained=False, **kwargs):
""" Nest-B @ 224x224, Pretrained weights converted from official Jax impl. """ Nest-B @ 224x224, Pretrained weights converted from official Jax impl.
""" """
kwargs['pad_type'] = 'same' kwargs['pad_type'] = 'same'
kwargs['weight_init'] = 'jax' model_kwargs = dict(embed_dims=(128, 256, 512), num_heads=(4, 8, 16), depths=(2, 2, 20), **kwargs)
model_kwargs = dict(
embed_dims=(128, 256, 512), num_heads=(4, 8, 16), depths=(2, 2, 20), drop_path_rate=0.5, **kwargs)
model = _create_nest('jx_nest_base', pretrained=pretrained, **model_kwargs) model = _create_nest('jx_nest_base', pretrained=pretrained, **model_kwargs)
return model return model
@ -464,9 +448,7 @@ def jx_nest_small(pretrained=False, **kwargs):
""" Nest-S @ 224x224, Pretrained weights converted from official Jax impl. """ Nest-S @ 224x224, Pretrained weights converted from official Jax impl.
""" """
kwargs['pad_type'] = 'same' kwargs['pad_type'] = 'same'
kwargs['weight_init'] = 'jax' model_kwargs = dict(embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 20), **kwargs)
model_kwargs = dict(
embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 20), drop_path_rate=0.3, **kwargs)
model = _create_nest('jx_nest_small', pretrained=pretrained, **model_kwargs) model = _create_nest('jx_nest_small', pretrained=pretrained, **model_kwargs)
return model return model
@ -476,8 +458,6 @@ def jx_nest_tiny(pretrained=False, **kwargs):
""" Nest-T @ 224x224, Pretrained weights converted from official Jax impl. """ Nest-T @ 224x224, Pretrained weights converted from official Jax impl.
""" """
kwargs['pad_type'] = 'same' kwargs['pad_type'] = 'same'
kwargs['weight_init'] = 'jax' model_kwargs = dict(embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 8), **kwargs)
model_kwargs = dict(
embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 8), drop_path_rate=0.2, **kwargs)
model = _create_nest('jx_nest_tiny', pretrained=pretrained, **model_kwargs) model = _create_nest('jx_nest_tiny', pretrained=pretrained, **model_kwargs)
return model return model
Loading…
Cancel
Save