From 23bb72ce5e7012147105df3835e9954f8fe540ae Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 24 Jun 2021 21:02:13 +0100 Subject: [PATCH 1/3] nested_transformer wip --- convert/convert_nest_flax.py | 107 ++++++++ timm/models/__init__.py | 1 + timm/models/nest.py | 480 +++++++++++++++++++++++++++++++++++ 3 files changed, 588 insertions(+) create mode 100644 convert/convert_nest_flax.py create mode 100644 timm/models/nest.py diff --git a/convert/convert_nest_flax.py b/convert/convert_nest_flax.py new file mode 100644 index 00000000..469b7e2f --- /dev/null +++ b/convert/convert_nest_flax.py @@ -0,0 +1,107 @@ +""" +Convert weights from https://github.com/google-research/nested-transformer +NOTE: You'll need https://github.com/google/CommonLoopUtils, not included in requirements.txt +""" + +import numpy as np +import torch + +from clu import checkpoint + + +arch_depths = { + 'nest_base': [2, 2, 20], + 'nest_small': [2, 2, 20], + 'nest_tiny': [2, 2, 8], +} + + +def convert_nest(checkpoint_path, arch): + """ + Expects path to checkpoint which is a dir containing 4 files like in each of these folders + - https://console.cloud.google.com/storage/browser/gresearch/nest-checkpoints + `arch` is needed to + Returns a state dict that can be used with `torch.nn.Module.load_state_dict` + Hint: Follow timm.models.nest.Nest.__init__ and + https://github.com/google-research/nested-transformer/blob/main/models/nest_net.py + """ + assert arch in ['nest_base', 'nest_small', 'nest_tiny'], "Your `arch` is not supported" + + flax_dict = checkpoint.load_state_dict(checkpoint_path)['optimizer']['target'] + state_dict = {} + + # Patch embedding + state_dict['patch_embed.proj.weight'] = torch.tensor( + flax_dict['PatchEmbedding_0']['Conv_0']['kernel']).permute(3, 2, 0, 1) + state_dict['patch_embed.proj.bias'] = torch.tensor(flax_dict['PatchEmbedding_0']['Conv_0']['bias']) + + # Positional embeddings + posemb_keys = [k for k in flax_dict.keys() if k.startswith('PositionEmbedding')] + for i, k in enumerate(posemb_keys): + state_dict[f'pos_embed_{i}'] = torch.tensor(flax_dict[k]['pos_embedding']) + + # Transformer encoders + depths = arch_depths[arch] + for level in range(len(depths)): + for layer in range(depths[level]): + global_layer_ix = sum(depths[:level]) + layer + # Norms + for i in range(2): + state_dict[f'ls_transformer_encoder.{level}.{layer}.norm{i+1}.weight'] = torch.tensor( + flax_dict[f'EncoderNDBlock_{global_layer_ix}'][f'LayerNorm_{i}']['scale']) + state_dict[f'ls_transformer_encoder.{level}.{layer}.norm{i+1}.bias'] = torch.tensor( + flax_dict[f'EncoderNDBlock_{global_layer_ix}'][f'LayerNorm_{i}']['bias']) + # Attention qkv + w_q = flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MultiHeadAttention_0']['DenseGeneral_0']['kernel'] + w_kv = flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MultiHeadAttention_0']['DenseGeneral_1']['kernel'] + # Pay attention to dims here (maybe get pen and paper) + w_kv = np.concatenate(np.split(w_kv, 2, -1), 1) + w_qkv = np.concatenate([w_q, w_kv], 1) + state_dict[f'ls_transformer_encoder.{level}.{layer}.attn.qkv.weight'] = torch.tensor(w_qkv).flatten(1).permute(1,0) + b_q = flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MultiHeadAttention_0']['DenseGeneral_0']['bias'] + b_kv = flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MultiHeadAttention_0']['DenseGeneral_1']['bias'] + # Pay attention to dims here (maybe get pen and paper) + b_kv = np.concatenate(np.split(b_kv, 2, -1), 0) + b_qkv = np.concatenate([b_q, b_kv], 0) + 
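+            # NOTE: added explanatory comment, assuming the Flax DenseGeneral kernels are laid
+            # out as (dim, heads, head_dim) for q and (dim, heads, 2 * head_dim) for the fused kv:
+            # the split/concat above rearranges them into one (dim, 3 * heads, head_dim) block,
+            # so flatten(1).permute(1, 0) yields the (3 * dim, dim) weight expected by the PyTorch
+            # nn.Linear qkv projection, and reshape(-1) below yields the flat 3 * dim bias.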
state_dict[f'ls_transformer_encoder.{level}.{layer}.attn.qkv.bias'] = torch.tensor(b_qkv).reshape(-1) + # Attention proj + w_proj = flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MultiHeadAttention_0']['proj_kernel'] + w_proj = torch.tensor(w_proj).permute(2, 1, 0).flatten(1) + state_dict[f'ls_transformer_encoder.{level}.{layer}.attn.proj.weight'] = w_proj + state_dict[f'ls_transformer_encoder.{level}.{layer}.attn.proj.bias'] = torch.tensor( + flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MultiHeadAttention_0']['bias']) + # MLP + for i in range(2): + state_dict[f'ls_transformer_encoder.{level}.{layer}.mlp.fc{i+1}.weight'] = torch.tensor( + flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MlpBlock_0'][f'Dense_{i}']['kernel']).permute(1, 0) + state_dict[f'ls_transformer_encoder.{level}.{layer}.mlp.fc{i+1}.bias'] = torch.tensor( + flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MlpBlock_0'][f'Dense_{i}']['bias']) + + # Block aggregations + for level in range(len(depths)-1): + # Convs + state_dict[f'ls_block_aggregation.{level}.conv.weight'] = torch.tensor( + flax_dict[f'ConvPool_{level}']['Conv_0']['kernel']).permute(3, 2, 0, 1) + state_dict[f'ls_block_aggregation.{level}.conv.bias'] = torch.tensor( + flax_dict[f'ConvPool_{level}']['Conv_0']['bias']) + # Norms + state_dict[f'ls_block_aggregation.{level}.norm.weight'] = torch.tensor( + flax_dict[f'ConvPool_{level}']['LayerNorm_0']['scale']) + state_dict[f'ls_block_aggregation.{level}.norm.bias'] = torch.tensor( + flax_dict[f'ConvPool_{level}']['LayerNorm_0']['bias']) + + # Final norm + state_dict[f'norm.weight'] = torch.tensor(flax_dict['LayerNorm_0']['scale']) + state_dict[f'norm.bias'] = torch.tensor(flax_dict['LayerNorm_0']['bias']) + + # Classifier + state_dict['head.weight'] = torch.tensor(flax_dict['Dense_0']['kernel']).permute(1, 0) + state_dict['head.bias'] = torch.tensor(flax_dict['Dense_0']['bias']) + + return state_dict + + +if __name__ == '__main__': + variant = 'base' + state_dict = convert_nest(f'../nested-transformer/checkpoints/nest-{variant[0]}_imagenet', f'nest_{variant}') + torch.save(state_dict, f'jx_nest_{variant}.pth') \ No newline at end of file diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 06217e18..5a48e325 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -20,6 +20,7 @@ from .levit import * from .mlp_mixer import * from .mobilenetv3 import * from .nasnet import * +from .nest import * from .nfnet import * from .pit import * from .pnasnet import * diff --git a/timm/models/nest.py b/timm/models/nest.py new file mode 100644 index 00000000..dcf664c6 --- /dev/null +++ b/timm/models/nest.py @@ -0,0 +1,480 @@ +""" Nested Transformer (NesT) in PyTorch + +A PyTorch implement of Aggregating Nested Transformers as described in: + +'Aggregating Nested Transformers' + - https://arxiv.org/abs/2105.12723 + +The official Jax code is released and available at https://github.com/google-research/nested-transformer + +Acknowledgments: +* The paper authors for sharing their research, code, and model weights +* Ross Wightman's existing code off which I based this + +Copyright 2021 Alexander Soare +""" + +import collections.abc +from functools import partial +import math +import logging + +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_ +from .layers.helpers import to_ntuple +from .layers.create_conv2d import 
create_conv2d +from .layers.pool2d_same import create_pool2d +from .vision_transformer import Block +from .registry import register_model +from .helpers import build_model_with_cfg, named_apply +from .vision_transformer import resize_pos_embed + +_logger = logging.getLogger(__name__) + + +# TODO check first_conv. everything else has been checked +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': [14, 14], + 'crop_pct': .875, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + # (weights from official Google JAX impl) + 'nest_base': _cfg(), + 'nest_small': _cfg(), + 'nest_tiny': _cfg(), + 'jx_nest_base': _cfg(url='https://www.todo-this-is-a-placeholder.com/jx_nest_base.pth'), # TODO + 'jx_nest_small': _cfg(url='https://www.todo-this-is-a-placeholder.com/jx_nest_small.pth'), # TODO + 'jx_nest_tiny': _cfg(url='https://www.todo-this-is-a-placeholder.com/jx_nest_tiny.pth'), # TODO +} + + +# TODO - Leave note for Ross - Maybe we can generalize Attention to this and put it in layers +class Attention(nn.Module): + """ + This is much like `.vision_transformer.Attention` but uses *localised* self attention by accepting an input with + an extra "image block" dim + """ + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, 3*dim, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + """ + x is shape: B (batch_size), T (image blocks), N (seq length per image block), C (embed dim) + """ + B, T, N, C = x.shape + # result of next line is (qkv, B, num (H)eads, T, N, (C')hannels per head) + qkv = self.qkv(x).reshape(B, T, N, 3, self.num_heads, C // self.num_heads).permute(3, 0, 4, 1, 2, 5) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale # (B, H, T, N, N) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + # (B, H, T, N, C'), permute -> (B, T, N, C', H) + x = (attn @ v).permute(0, 2, 3, 4, 1).reshape(B, T, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x # (B, T, N, C) + + +class TransformerLayer(Block): + """ + This is much like `.vision_transformer.Block` but: + - Called TransformerLayer here to allow for "block" as defined in the paper ("non-overlapping image blocks") + - Uses modified Attention layer that handles the "block" dimension + TODO somehow reuse the code instead of rewriting it... + """ + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__(dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm) + self.norm1 = norm_layer(dim) + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + y = self.norm1(x) + x = x + self.drop_path(self.attn(y)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class BlockAggregation(nn.Module): + def __init__(self, in_channels, out_channels, norm_layer, pad_type=''): + super().__init__() + self.conv = create_conv2d(in_channels, out_channels, kernel_size=3, padding=pad_type, bias=True) + self.norm = norm_layer(out_channels) + self.pool = create_pool2d('max', kernel_size=3, stride=2, padding=pad_type) + + def forward(self, x): + """ + x is expected to have shape (B, C, H, W) + """ + assert x.shape[-2] % 2 == 0, 'BlockAggregation requires even input spatial dims' + assert x.shape[-1] % 2 == 0, 'BlockAggregation requires even input spatial dims' + x = self.conv(x) + # Layer norm done over channel dim only + x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + x = self.pool(x) + return x # (B, C, H//2, W//2) + + +def blockify(x, block_size: int): + """image to blocks + Args: + x (Tensor): with shape (B, H, W, C) + block_size (int): edge length of a single square block in units of H, W + """ + B, H, W, C = x.shape + assert H % block_size == 0, '`block_size` must divide input height evenly' + assert W % block_size == 0, '`block_size` must divide input width evenly' + grid_height = H // block_size + grid_width = W // block_size + x = x.reshape(B, grid_height, block_size, grid_width, block_size, C) + x = x.permute(0, 1, 3, 2, 4, 5) + x = x.reshape(B, grid_height * grid_width, -1, C) + return x # (B, T, N, C) + + +def deblockify(x, block_size: int): + """blocks to image + Args: + x (Tensor): with shape (B, T, N, C) where T is number of blocks and N is sequence size per block + block_size (int): edge length of a single square block in units of desired H, W + """ + B, T, _, C= x.shape + grid_size = int(math.sqrt(T)) + x = x.reshape(B, grid_size, grid_size, block_size, block_size, C) + x = x.permute(0, 1, 3, 2, 4, 5) + height = width = grid_size * block_size + x = x.reshape(B, height, width, C) + return x # (B, H, W, C) + + +class Nest(nn.Module): + """ Nested Transformer (NesT) + + A PyTorch impl of : `Aggregating Nested Transformers` + - https://arxiv.org/abs/2105.12723 + """ + + def __init__(self, img_size=224, in_chans=3, patch_size=4, num_levels=3, embed_dims=(128, 256, 512), + num_heads=(4, 8, 16), depths=(2, 2, 20), num_classes=1000, mlp_ratio=4., + qkv_bias=True, pad_type='', + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.5, norm_layer=None, + act_layer=None, weight_init='', global_pool='avg'): + """ + Args: + img_size (int, tuple): input image size + in_chans (int): number of input channels + patch_size (int): patch size + num_levels (int): number of block hierarchies (T_d in the paper) + embed_dims (int, tuple): embedding dimensions of each level + num_heads (int, tuple): number of attention heads for each level + depths (int, tuple): number of transformer layers for each level + num_classes (int): number of classes for classification head + mlp_ratio (int): ratio of mlp hidden dim to embedding dim for MLP of transformer layers + qkv_bias (bool): enable bias for qkv if True + drop_rate (float): dropout rate for MLP of transformer layers, MSA final projection layer, and classifier + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + norm_layer: (nn.Module): 
normalization layer for transformer layers + act_layer: (nn.Module): activation layer in MLP of transformer layers + weight_init: (str): weight init scheme TODO check + global_pool: (str): type of pooling operation to apply to final feature map + + Notes: + - Default values follow NesT-B from the original Jax code. + - `embed_dims`, `num_heads`, `depths` should be ints or tuples with length `num_levels`. + - For those following the paper, Table A1 may have errors! + - https://github.com/google-research/nested-transformer/issues/2 + """ + super().__init__() + + for param_name in ['embed_dims', 'num_heads', 'depths']: + param_value = locals()[param_name] + if isinstance(param_value, collections.abc.Sequence): + assert len(param_value) == num_levels, f'Require `len({param_name}) == num_levels`' + + embed_dims = to_ntuple(num_levels)(embed_dims) + num_heads = to_ntuple(num_levels)(num_heads) + depths = to_ntuple(num_levels)(depths) + self.num_classes = num_classes + self.num_features = embed_dims[-1] + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + self.drop_rate = drop_rate + self.num_levels = num_levels + if isinstance(img_size, collections.abc.Sequence): + assert img_size[0] == img_size[1], 'Model only handles square inputs' + img_size = img_size[0] + assert img_size % patch_size == 0, '`patch_size` must divide `img_size` evenly' + self.patch_size = patch_size + + # Number of blocks at each level + self.num_blocks = 4**(np.arange(num_levels)[::-1]) + assert (img_size // patch_size) % np.sqrt(self.num_blocks[0]) == 0, \ + 'First level blocks don\'t fit evenly. Check `img_size`, `patch_size`, and `num_levels`' + + # Block edge size in units of patches + # Hint: (img_size // patch_size) gives number of patches along edge of image. 
sqrt(self.num_blocks[0]) is the + # number of blocks along edge of image + self.block_size = int((img_size // patch_size) // np.sqrt(self.num_blocks[0])) + + # Patch embedding + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0]) + self.num_patches = self.patch_embed.num_patches + + # Build up each hierarchical level + self.ls_pos_embed = [] + self.ls_transformer_encoder = nn.ModuleList([]) + self.ls_block_aggregation = nn.ModuleList([]) + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # drop path rate + self.feature_info = [] + for level in range(self.num_levels): + # Positional embedding + # NOTE: Can't use ParameterList for positional embedding as it can't be enumerated with TorchScript + pos_embed = nn.Parameter( + torch.zeros(1, self.num_blocks[level], self.num_patches // self.num_blocks[0], embed_dims[level])) + self.register_parameter(f'pos_embed_{level}', pos_embed) + self.ls_pos_embed.append(pos_embed) + # Transformer encoder + self.ls_transformer_encoder.append(nn.Sequential(*[ + TransformerLayer( + dim=embed_dims[level], num_heads=num_heads[level], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:level]) + i], + norm_layer=norm_layer, act_layer=act_layer) + for i in range(depths[level])])) + + self.feature_info.append(dict( + num_chs=embed_dims[level], reduction=2, + module=f'ls_transformer_encoder.{level}.{depths[level]-1}.mlp.fc2')) + + # Block aggregation (not required for last level) + if level < self.num_levels - 1: + self.ls_block_aggregation.append( + BlockAggregation(embed_dims[level], embed_dims[level+1], norm_layer, pad_type=pad_type)) + else: + # NOTE: Required for enumeration over all level components at once + self.ls_block_aggregation.append(nn.Identity()) + self.ls_pos_embed = tuple(self.ls_pos_embed) # static length required for torchscript + + + # Final normalization layer + self.norm = norm_layer(embed_dims[-1]) + + # Classifier + self.global_pool, self.head = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + self.init_weights(weight_init) + + def init_weights(self, mode=''): + assert mode in ('jax', 'jax_nlhb', 'nlhb', '') + head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 
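+        # 'nlhb' (negative log head bias): with zero-initialized head weights, a bias of
+        # -log(num_classes) makes the initial softmax output roughly uniform, so training
+        # starts from a cross-entropy of about log(num_classes).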
+ for pos_embed in self.ls_pos_embed: + trunc_normal_(pos_embed, std=.02, a=-2, b=2) + if mode.startswith('jax'): + named_apply(partial(_init_nest_weights, head_bias=head_bias, jax_impl=True), self) + else: + self.apply(_init_nest_weights) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.head = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + """ x shape (B, C, H, W) + """ + B, _, H, W = x.shape + x = self.patch_embed(x) # (B, N, C) + x = x.reshape(B, H//self.patch_size, W//self.patch_size, -1) # (B, H', W', C') + # NOTE: TorchScript wants enumeration rather than subscripting of ModuleList + for level, (pos_embed, transformer, block_agg) in enumerate( + zip(self.ls_pos_embed, self.ls_transformer_encoder, self.ls_block_aggregation)): + if level > 0: + # Switch back to channels last for transformer + x = x.permute(0, 2, 3, 1) # (B, H', W', C) + x = blockify(x, self.block_size) # (B, T, N, C') + x = x + pos_embed + x = transformer(x) # (B, T, N, C') + x = deblockify(x, self.block_size) # (B, H', W', C') + # Channel-first for block aggregation, and generally to replicate convnet feature map at each stage + x = x.permute(0, 3, 1, 2) # (B, C, H', W') + if level < self.num_levels - 1: + x = block_agg(x) # (B, C', H'//2, W'//2) + # Layer norm done over channel dim only + x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + + def forward(self, x): + """ x shape (B, C, H, W) + """ + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + return self.head(x) + + +def _init_nest_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False): + """ NesT weight initialization + Can replicate Jax implementation. 
Otherwise follows vision_transformer.py + """ + if isinstance(module, nn.Linear): + if name.startswith('head'): + if jax_impl: + trunc_normal_(module.weight, std=.02, a=-2, b=2) + else: + nn.init.zeros_(module.weight) + nn.init.constant_(module.bias, head_bias) + else: + trunc_normal_(module.weight, std=.02, a=-2, b=2) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif jax_impl and isinstance(module, nn.Conv2d): + trunc_normal_(module.weight, std=.02, a=-2, b=2) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)): + nn.init.zeros_(module.bias) + nn.init.ones_(module.weight) + + +def resize_pos_embed(posemb, posemb_new): + """ + Rescale the grid of position embeddings when loading from state_dict + Expected shape of position embeddings is (1, T, N, C), and considers only square images + """ + _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) + seq_length_old = posemb.shape[2] + num_blocks_new, seq_length_new = posemb_new.shape[1:3] + size_new = int(math.sqrt(num_blocks_new*seq_length_new)) + # First change to (1, C, H, W) + posemb = deblockify(posemb, int(math.sqrt(seq_length_old))).permute(0, 3, 1, 2) + posemb = F.interpolate(posemb, size=[size_new, size_new], mode='bilinear') + # Now change to new (1, T, N, C) + posemb = blockify(posemb.permute(0, 2, 3, 1), int(math.sqrt(seq_length_new))) + return posemb + + +def checkpoint_filter_fn(state_dict, model): + """ resize positional embeddings of pretrained weights """ + pos_embed_keys = [k for k in state_dict.keys() if k.startswith('pos_embed_')] + for k in pos_embed_keys: + if state_dict[k].shape != getattr(model, k).shape: + state_dict[k] = resize_pos_embed(state_dict[k], getattr(model, k)) + return state_dict + + +def _create_nest(variant, pretrained=False, default_cfg=None, **kwargs): + # if kwargs.get('features_only', None): + # raise RuntimeError('features_only not implemented for Vision Transformer models.') + + default_cfg = default_cfg or default_cfgs[variant] + + model = build_model_with_cfg( + Nest, variant, pretrained, + default_cfg=default_cfg, + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + + return model + + +@register_model +def nest_base(pretrained=False, **kwargs): + """ Nest-B @ 224x224 + """ + model_kwargs = dict( + embed_dims=(128, 256, 512), num_heads=(4, 8, 16), depths=(2, 2, 20), drop_path_rate=0.5, **kwargs) + model = _create_nest('nest_base', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def nest_small(pretrained=False, **kwargs): + """ Nest-S @ 224x224 + """ + model_kwargs = dict( + embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 20), drop_path_rate=0.3, **kwargs) + model = _create_nest('nest_small', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def nest_tiny(pretrained=False, **kwargs): + """ Nest-T @ 224x224 + """ + model_kwargs = dict( + embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 8), drop_path_rate=0.2, **kwargs) + model = _create_nest('nest_tiny', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def jx_nest_base(pretrained=False, **kwargs): + """ Nest-B @ 224x224, Pretrained weights converted from official Jax impl. 
+ """ + kwargs['pad_type'] = 'same' + kwargs['weight_init'] = 'jax' + model_kwargs = dict( + embed_dims=(128, 256, 512), num_heads=(4, 8, 16), depths=(2, 2, 20), drop_path_rate=0.5, **kwargs) + model = _create_nest('jx_nest_base', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def jx_nest_small(pretrained=False, **kwargs): + """ Nest-S @ 224x224, Pretrained weights converted from official Jax impl. + """ + kwargs['pad_type'] = 'same' + kwargs['weight_init'] = 'jax' + model_kwargs = dict( + embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 20), drop_path_rate=0.3, **kwargs) + model = _create_nest('jx_nest_small', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def jx_nest_tiny(pretrained=False, **kwargs): + """ Nest-T @ 224x224, Pretrained weights converted from official Jax impl. + """ + kwargs['pad_type'] = 'same' + kwargs['weight_init'] = 'jax' + model_kwargs = dict( + embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 8), drop_path_rate=0.2, **kwargs) + model = _create_nest('jx_nest_tiny', pretrained=pretrained, **model_kwargs) + return model From b11d949a06f4c19be616aba25e49bf33d13a6b23 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Sat, 3 Jul 2021 11:45:19 +0100 Subject: [PATCH 2/3] wip checkpoint with some feature extraction work --- timm/models/nest.py | 135 +++++++++++++++++++++++++------------------- 1 file changed, 76 insertions(+), 59 deletions(-) diff --git a/timm/models/nest.py b/timm/models/nest.py index dcf664c6..0131b677 100644 --- a/timm/models/nest.py +++ b/timm/models/nest.py @@ -5,7 +5,8 @@ A PyTorch implement of Aggregating Nested Transformers as described in: 'Aggregating Nested Transformers' - https://arxiv.org/abs/2105.12723 -The official Jax code is released and available at https://github.com/google-research/nested-transformer +The official Jax code is released and available at https://github.com/google-research/nested-transformer. The weights +have been converted with convert/convert_nest_flax.py Acknowledgments: * The paper authors for sharing their research, code, and model weights @@ -37,7 +38,6 @@ from .vision_transformer import resize_pos_embed _logger = logging.getLogger(__name__) -# TODO check first_conv. everything else has been checked def _cfg(url='', **kwargs): return { 'url': url, @@ -60,7 +60,6 @@ default_cfgs = { } -# TODO - Leave note for Ross - Maybe we can generalize Attention to this and put it in layers class Attention(nn.Module): """ This is much like `.vision_transformer.Attention` but uses *localised* self attention by accepting an input with @@ -102,7 +101,6 @@ class TransformerLayer(Block): This is much like `.vision_transformer.Block` but: - Called TransformerLayer here to allow for "block" as defined in the paper ("non-overlapping image blocks") - Uses modified Attention layer that handles the "block" dimension - TODO somehow reuse the code instead of rewriting it... 
""" def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): @@ -172,6 +170,39 @@ def deblockify(x, block_size: int): height = width = grid_size * block_size x = x.reshape(B, height, width, C) return x # (B, H, W, C) + + +class NestLevel(nn.Module): + """ Single hierarchical level of a Nested Transformer + """ + def __init__(self, num_blocks, block_size, seq_length, num_heads, depth, embed_dim, mlp_ratio=4., qkv_bias=True, + drop_rate=0., attn_drop_rate=0., drop_path_rates=[], norm_layer=None, act_layer=None): + super().__init__() + self.block_size = block_size + self.pos_embed = nn.Parameter(torch.zeros(1, num_blocks, seq_length, embed_dim)) + # Transformer encoder + if len(drop_path_rates): + assert len(drop_path_rates) == depth, 'Must provide as many drop path rates as there are transformer layers' + self.transformer_encoder = nn.Sequential(*[ + TransformerLayer( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rates[i], + norm_layer=norm_layer, act_layer=act_layer) + for i in range(depth)]) + + def forward(self, x): + """ + expects x as (B, C, H, W) + """ + # Switch to channels last for transformer + x = x.permute(0, 2, 3, 1) # (B, H', W', C) + x = blockify(x, self.block_size) # (B, T, N, C') + x = x + self.pos_embed + x = self.transformer_encoder(x) # (B, T, N, C') + x = deblockify(x, self.block_size) # (B, H', W', C') + # Channel-first for block aggregation, and generally to replicate convnet feature map at each stage + x = x.permute(0, 3, 1, 2) # (B, C, H', W') + return x class Nest(nn.Module): @@ -182,10 +213,9 @@ class Nest(nn.Module): """ def __init__(self, img_size=224, in_chans=3, patch_size=4, num_levels=3, embed_dims=(128, 256, 512), - num_heads=(4, 8, 16), depths=(2, 2, 20), num_classes=1000, mlp_ratio=4., - qkv_bias=True, pad_type='', - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.5, norm_layer=None, - act_layer=None, weight_init='', global_pool='avg'): + num_heads=(4, 8, 16), depths=(2, 2, 20), num_classes=1000, mlp_ratio=4., qkv_bias=True, pad_type='', + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.5, norm_layer=None, act_layer=None, weight_init='', + global_pool='avg'): """ Args: img_size (int, tuple): input image size @@ -203,7 +233,7 @@ class Nest(nn.Module): drop_path_rate (float): stochastic depth rate norm_layer: (nn.Module): normalization layer for transformer layers act_layer: (nn.Module): activation layer in MLP of transformer layers - weight_init: (str): weight init scheme TODO check + weight_init: (str): weight init scheme global_pool: (str): type of pooling operation to apply to final feature map Notes: @@ -247,45 +277,33 @@ class Nest(nn.Module): # Patch embedding self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0]) + self.feature_info = [dict(num_chs=embed_dims[0], reduction=patch_size, module='patch_embed')] self.num_patches = self.patch_embed.num_patches + self.seq_length = self.num_patches // self.num_blocks[0] # Build up each hierarchical level - self.ls_pos_embed = [] - self.ls_transformer_encoder = nn.ModuleList([]) - self.ls_block_aggregation = nn.ModuleList([]) - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # drop path rate - self.feature_info = [] - for level in range(self.num_levels): - # Positional embedding - # NOTE: Can't use ParameterList for positional embedding 
as it can't be enumerated with TorchScript - pos_embed = nn.Parameter( - torch.zeros(1, self.num_blocks[level], self.num_patches // self.num_blocks[0], embed_dims[level])) - self.register_parameter(f'pos_embed_{level}', pos_embed) - self.ls_pos_embed.append(pos_embed) - # Transformer encoder - self.ls_transformer_encoder.append(nn.Sequential(*[ - TransformerLayer( - dim=embed_dims[level], num_heads=num_heads[level], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:level]) + i], - norm_layer=norm_layer, act_layer=act_layer) - for i in range(depths[level])])) - - self.feature_info.append(dict( - num_chs=embed_dims[level], reduction=2, - module=f'ls_transformer_encoder.{level}.{depths[level]-1}.mlp.fc2')) - - # Block aggregation (not required for last level) - if level < self.num_levels - 1: - self.ls_block_aggregation.append( - BlockAggregation(embed_dims[level], embed_dims[level+1], norm_layer, pad_type=pad_type)) + self.levels = nn.ModuleList([]) + self.block_aggs = nn.ModuleList([]) + drop_path_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + for lix in range(self.num_levels): + dpr = drop_path_rates[sum(depths[:lix]):sum(depths[:lix+1])] + self.levels.append(NestLevel( + self.num_blocks[lix], self.block_size, self.seq_length, num_heads[lix], depths[lix], + embed_dims[lix], mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, dpr, norm_layer, + act_layer)) + self.feature_info.append( + dict(num_chs=embed_dims[lix], reduction=self.feature_info[-1]['reduction']*2, module=f'levels.{lix}')) + if lix < self.num_levels - 1: + self.block_aggs.append(BlockAggregation( + embed_dims[lix], embed_dims[lix+1], norm_layer, pad_type=pad_type)) else: - # NOTE: Required for enumeration over all level components at once - self.ls_block_aggregation.append(nn.Identity()) - self.ls_pos_embed = tuple(self.ls_pos_embed) # static length required for torchscript + # Required for zipped iteration over levels and ls_block_agg together + self.block_aggs.append(nn.Identity()) - # Final normalization layer self.norm = norm_layer(embed_dims[-1]) + self.feature_info.append( + dict(num_chs=embed_dims[lix], reduction=self.feature_info[-1]['reduction'], module='norm')) # Classifier self.global_pool, self.head = create_classifier( @@ -296,8 +314,8 @@ class Nest(nn.Module): def init_weights(self, mode=''): assert mode in ('jax', 'jax_nlhb', 'nlhb', '') head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 
- for pos_embed in self.ls_pos_embed: - trunc_normal_(pos_embed, std=.02, a=-2, b=2) + for level in self.levels: + trunc_normal_(level.pos_embed, std=.02, a=-2, b=2) if mode.startswith('jax'): named_apply(partial(_init_nest_weights, head_bias=head_bias, jax_impl=True), self) else: @@ -319,22 +337,13 @@ class Nest(nn.Module): """ x shape (B, C, H, W) """ B, _, H, W = x.shape - x = self.patch_embed(x) # (B, N, C) + x = self.patch_embed(x) x = x.reshape(B, H//self.patch_size, W//self.patch_size, -1) # (B, H', W', C') - # NOTE: TorchScript wants enumeration rather than subscripting of ModuleList - for level, (pos_embed, transformer, block_agg) in enumerate( - zip(self.ls_pos_embed, self.ls_transformer_encoder, self.ls_block_aggregation)): - if level > 0: - # Switch back to channels last for transformer - x = x.permute(0, 2, 3, 1) # (B, H', W', C) - x = blockify(x, self.block_size) # (B, T, N, C') - x = x + pos_embed - x = transformer(x) # (B, T, N, C') - x = deblockify(x, self.block_size) # (B, H', W', C') - # Channel-first for block aggregation, and generally to replicate convnet feature map at each stage - x = x.permute(0, 3, 1, 2) # (B, C, H', W') - if level < self.num_levels - 1: - x = block_agg(x) # (B, C', H'//2, W'//2) + x = x.permute(0, 3, 1, 2) + # NOTE: TorchScript won't let us subscript module lists with integer variables, so we iterate instead + for level, block_agg in zip(self.levels, self.block_aggs): + x = level(x) + x = block_agg(x) # Layer norm done over channel dim only x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) return x @@ -404,11 +413,12 @@ def _create_nest(variant, pretrained=False, default_cfg=None, **kwargs): # raise RuntimeError('features_only not implemented for Vision Transformer models.') default_cfg = default_cfg or default_cfgs[variant] - model = build_model_with_cfg( Nest, variant, pretrained, default_cfg=default_cfg, pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict( + out_indices=tuple(range(kwargs.get('num_levels', 3) + 2)), feature_cls='hook', flatten_sequential=True), **kwargs) return model @@ -478,3 +488,10 @@ def jx_nest_tiny(pretrained=False, **kwargs): embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 8), drop_path_rate=0.2, **kwargs) model = _create_nest('jx_nest_tiny', pretrained=pretrained, **model_kwargs) return model + + +if __name__ == '__main__': + model = jx_nest_base() + model = torch.jit.script(model) + inp = torch.zeros(8, 3, 224, 224) + print(model.forward_features(inp).shape) \ No newline at end of file From 7b8a0017f125e4a30bbc38906a1f64dc036a9884 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Sat, 3 Jul 2021 12:10:12 +0100 Subject: [PATCH 3/3] wip to review --- convert/convert_nest_flax.py | 32 +++++++++++++++++--------------- tests/test_models.py | 2 +- timm/models/nest.py | 22 ++++------------------ 3 files changed, 22 insertions(+), 34 deletions(-) diff --git a/convert/convert_nest_flax.py b/convert/convert_nest_flax.py index 469b7e2f..513aa803 100644 --- a/convert/convert_nest_flax.py +++ b/convert/convert_nest_flax.py @@ -3,6 +3,8 @@ Convert weights from https://github.com/google-research/nested-transformer NOTE: You'll need https://github.com/google/CommonLoopUtils, not included in requirements.txt """ +import sys + import numpy as np import torch @@ -38,7 +40,7 @@ def convert_nest(checkpoint_path, arch): # Positional embeddings posemb_keys = [k for k in flax_dict.keys() if k.startswith('PositionEmbedding')] for i, k in enumerate(posemb_keys): - state_dict[f'pos_embed_{i}'] = 
torch.tensor(flax_dict[k]['pos_embedding']) + state_dict[f'levels.{i}.pos_embed'] = torch.tensor(flax_dict[k]['pos_embedding']) # Transformer encoders depths = arch_depths[arch] @@ -47,9 +49,9 @@ def convert_nest(checkpoint_path, arch): global_layer_ix = sum(depths[:level]) + layer # Norms for i in range(2): - state_dict[f'ls_transformer_encoder.{level}.{layer}.norm{i+1}.weight'] = torch.tensor( + state_dict[f'levels.{level}.transformer_encoder.{layer}.norm{i+1}.weight'] = torch.tensor( flax_dict[f'EncoderNDBlock_{global_layer_ix}'][f'LayerNorm_{i}']['scale']) - state_dict[f'ls_transformer_encoder.{level}.{layer}.norm{i+1}.bias'] = torch.tensor( + state_dict[f'levels.{level}.transformer_encoder.{layer}.norm{i+1}.bias'] = torch.tensor( flax_dict[f'EncoderNDBlock_{global_layer_ix}'][f'LayerNorm_{i}']['bias']) # Attention qkv w_q = flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MultiHeadAttention_0']['DenseGeneral_0']['kernel'] @@ -57,37 +59,37 @@ def convert_nest(checkpoint_path, arch): # Pay attention to dims here (maybe get pen and paper) w_kv = np.concatenate(np.split(w_kv, 2, -1), 1) w_qkv = np.concatenate([w_q, w_kv], 1) - state_dict[f'ls_transformer_encoder.{level}.{layer}.attn.qkv.weight'] = torch.tensor(w_qkv).flatten(1).permute(1,0) + state_dict[f'levels.{level}.transformer_encoder.{layer}.attn.qkv.weight'] = torch.tensor(w_qkv).flatten(1).permute(1,0) b_q = flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MultiHeadAttention_0']['DenseGeneral_0']['bias'] b_kv = flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MultiHeadAttention_0']['DenseGeneral_1']['bias'] # Pay attention to dims here (maybe get pen and paper) b_kv = np.concatenate(np.split(b_kv, 2, -1), 0) b_qkv = np.concatenate([b_q, b_kv], 0) - state_dict[f'ls_transformer_encoder.{level}.{layer}.attn.qkv.bias'] = torch.tensor(b_qkv).reshape(-1) + state_dict[f'levels.{level}.transformer_encoder.{layer}.attn.qkv.bias'] = torch.tensor(b_qkv).reshape(-1) # Attention proj w_proj = flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MultiHeadAttention_0']['proj_kernel'] w_proj = torch.tensor(w_proj).permute(2, 1, 0).flatten(1) - state_dict[f'ls_transformer_encoder.{level}.{layer}.attn.proj.weight'] = w_proj - state_dict[f'ls_transformer_encoder.{level}.{layer}.attn.proj.bias'] = torch.tensor( + state_dict[f'levels.{level}.transformer_encoder.{layer}.attn.proj.weight'] = w_proj + state_dict[f'levels.{level}.transformer_encoder.{layer}.attn.proj.bias'] = torch.tensor( flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MultiHeadAttention_0']['bias']) # MLP for i in range(2): - state_dict[f'ls_transformer_encoder.{level}.{layer}.mlp.fc{i+1}.weight'] = torch.tensor( + state_dict[f'levels.{level}.transformer_encoder.{layer}.mlp.fc{i+1}.weight'] = torch.tensor( flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MlpBlock_0'][f'Dense_{i}']['kernel']).permute(1, 0) - state_dict[f'ls_transformer_encoder.{level}.{layer}.mlp.fc{i+1}.bias'] = torch.tensor( + state_dict[f'levels.{level}.transformer_encoder.{layer}.mlp.fc{i+1}.bias'] = torch.tensor( flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MlpBlock_0'][f'Dense_{i}']['bias']) # Block aggregations for level in range(len(depths)-1): # Convs - state_dict[f'ls_block_aggregation.{level}.conv.weight'] = torch.tensor( + state_dict[f'block_aggs.{level}.conv.weight'] = torch.tensor( flax_dict[f'ConvPool_{level}']['Conv_0']['kernel']).permute(3, 2, 0, 1) - state_dict[f'ls_block_aggregation.{level}.conv.bias'] = torch.tensor( + state_dict[f'block_aggs.{level}.conv.bias'] = torch.tensor( 
flax_dict[f'ConvPool_{level}']['Conv_0']['bias']) # Norms - state_dict[f'ls_block_aggregation.{level}.norm.weight'] = torch.tensor( + state_dict[f'block_aggs.{level}.norm.weight'] = torch.tensor( flax_dict[f'ConvPool_{level}']['LayerNorm_0']['scale']) - state_dict[f'ls_block_aggregation.{level}.norm.bias'] = torch.tensor( + state_dict[f'block_aggs.{level}.norm.bias'] = torch.tensor( flax_dict[f'ConvPool_{level}']['LayerNorm_0']['bias']) # Final norm @@ -102,6 +104,6 @@ def convert_nest(checkpoint_path, arch): if __name__ == '__main__': - variant = 'base' + variant = sys.argv[1] # base, small, or tiny state_dict = convert_nest(f'../nested-transformer/checkpoints/nest-{variant[0]}_imagenet', f'nest_{variant}') - torch.save(state_dict, f'jx_nest_{variant}.pth') \ No newline at end of file + torch.save(state_dict, f'/home/alexander/.cache/torch/hub/checkpoints/jx_nest_{variant}.pth') \ No newline at end of file diff --git a/tests/test_models.py b/tests/test_models.py index 5c8b02db..fe4fcd9a 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -17,7 +17,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'): # transformer models don't support many of the spatial / feature based model functionalities NON_STD_FILTERS = [ 'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', - 'convit_*', 'levit*', 'visformer*', 'deit*'] + 'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*'] NUM_NON_STD = len(NON_STD_FILTERS) # exclude models that cause specific test failures diff --git a/timm/models/nest.py b/timm/models/nest.py index 0131b677..d2dc359a 100644 --- a/timm/models/nest.py +++ b/timm/models/nest.py @@ -277,7 +277,6 @@ class Nest(nn.Module): # Patch embedding self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0]) - self.feature_info = [dict(num_chs=embed_dims[0], reduction=patch_size, module='patch_embed')] self.num_patches = self.patch_embed.num_patches self.seq_length = self.num_patches // self.num_blocks[0] @@ -291,8 +290,6 @@ class Nest(nn.Module): self.num_blocks[lix], self.block_size, self.seq_length, num_heads[lix], depths[lix], embed_dims[lix], mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, dpr, norm_layer, act_layer)) - self.feature_info.append( - dict(num_chs=embed_dims[lix], reduction=self.feature_info[-1]['reduction']*2, module=f'levels.{lix}')) if lix < self.num_levels - 1: self.block_aggs.append(BlockAggregation( embed_dims[lix], embed_dims[lix+1], norm_layer, pad_type=pad_type)) @@ -302,8 +299,6 @@ class Nest(nn.Module): # Final normalization layer self.norm = norm_layer(embed_dims[-1]) - self.feature_info.append( - dict(num_chs=embed_dims[lix], reduction=self.feature_info[-1]['reduction'], module='norm')) # Classifier self.global_pool, self.head = create_classifier( @@ -319,7 +314,7 @@ class Nest(nn.Module): if mode.startswith('jax'): named_apply(partial(_init_nest_weights, head_bias=head_bias, jax_impl=True), self) else: - self.apply(_init_nest_weights) + named_apply(_init_nest_weights, self) @torch.jit.ignore def no_weight_decay(self): @@ -409,16 +404,14 @@ def checkpoint_filter_fn(state_dict, model): def _create_nest(variant, pretrained=False, default_cfg=None, **kwargs): - # if kwargs.get('features_only', None): - # raise RuntimeError('features_only not implemented for Vision Transformer models.') + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') default_cfg = 
default_cfg or default_cfgs[variant] model = build_model_with_cfg( Nest, variant, pretrained, default_cfg=default_cfg, pretrained_filter_fn=checkpoint_filter_fn, - feature_cfg=dict( - out_indices=tuple(range(kwargs.get('num_levels', 3) + 2)), feature_cls='hook', flatten_sequential=True), **kwargs) return model @@ -487,11 +480,4 @@ def jx_nest_tiny(pretrained=False, **kwargs): model_kwargs = dict( embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 8), drop_path_rate=0.2, **kwargs) model = _create_nest('jx_nest_tiny', pretrained=pretrained, **model_kwargs) - return model - - -if __name__ == '__main__': - model = jx_nest_base() - model = torch.jit.script(model) - inp = torch.zeros(8, 3, 224, 224) - print(model.forward_features(inp).shape) \ No newline at end of file + return model \ No newline at end of file
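
Quick smoke test for a converted checkpoint (a minimal sketch, not part of the patches above; the checkpoint path is an assumption and should point at the output of convert/convert_nest_flax.py):

import torch
import timm  # importing timm registers jx_nest_* via timm/models/nest.py

# Build the architecture without trying to download from the placeholder URL.
model = timm.create_model('jx_nest_base', pretrained=False)

# Load a locally converted checkpoint (keys follow the final levels.* / block_aggs.* naming).
state_dict = torch.load('jx_nest_base.pth', map_location='cpu')
model.load_state_dict(state_dict)

model.eval()
with torch.no_grad():
    out = model(torch.zeros(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 1000])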