From 81cd6863c8c9515de8884e8a8ea0445ec08b4486 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 5 Jul 2021 18:20:49 -0700
Subject: [PATCH] Move aggregation (ConvPool) for NesT into NestLevel, clean
 up and enable features_only use. Finalize weight URLs.

---
 README.md                    |   3 +
 convert/convert_nest_flax.py |  24 ++---
 timm/models/nest.py          | 180 ++++++++++++++++-------------
 3 files changed, 95 insertions(+), 112 deletions(-)

diff --git a/README.md b/README.md
index 168f254a..52acfbb9 100644
--- a/README.md
+++ b/README.md
@@ -23,6 +23,9 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
 
 ## What's New
 
+### July 5, 2021
+* Add 'Aggregating Nested Transformer' (NesT) w/ weights converted from official [Flax impl](https://github.com/google-research/nested-transformer). Contributed by [Alexander Soare](https://github.com/alexander-soare).
+
 ### June 23, 2021
 * Reproduce gMLP model training, `gmlp_s16_224` trained to 79.6 top-1, matching [paper](https://arxiv.org/abs/2105.08050). Hparams for this and other recent MLP training [here](https://gist.github.com/rwightman/d6c264a9001f9167e06c209f630b2cc6)
 
diff --git a/convert/convert_nest_flax.py b/convert/convert_nest_flax.py
index 513aa803..cda4d34f 100644
--- a/convert/convert_nest_flax.py
+++ b/convert/convert_nest_flax.py
@@ -79,18 +79,18 @@ def convert_nest(checkpoint_path, arch):
             state_dict[f'levels.{level}.transformer_encoder.{layer}.mlp.fc{i+1}.bias'] = torch.tensor(
                 flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MlpBlock_0'][f'Dense_{i}']['bias'])
 
-    # Block aggregations
-    for level in range(len(depths)-1):
+    # Block aggregations (ConvPool)
+    for level in range(1, len(depths)):
         # Convs
-        state_dict[f'block_aggs.{level}.conv.weight'] = torch.tensor(
-            flax_dict[f'ConvPool_{level}']['Conv_0']['kernel']).permute(3, 2, 0, 1)
-        state_dict[f'block_aggs.{level}.conv.bias'] = torch.tensor(
-            flax_dict[f'ConvPool_{level}']['Conv_0']['bias'])
+        state_dict[f'levels.{level}.pool.conv.weight'] = torch.tensor(
+            flax_dict[f'ConvPool_{level-1}']['Conv_0']['kernel']).permute(3, 2, 0, 1)
+        state_dict[f'levels.{level}.pool.conv.bias'] = torch.tensor(
+            flax_dict[f'ConvPool_{level-1}']['Conv_0']['bias'])
         # Norms
-        state_dict[f'block_aggs.{level}.norm.weight'] = torch.tensor(
-            flax_dict[f'ConvPool_{level}']['LayerNorm_0']['scale'])
-        state_dict[f'block_aggs.{level}.norm.bias'] = torch.tensor(
-            flax_dict[f'ConvPool_{level}']['LayerNorm_0']['bias'])
+        state_dict[f'levels.{level}.pool.norm.weight'] = torch.tensor(
+            flax_dict[f'ConvPool_{level-1}']['LayerNorm_0']['scale'])
+        state_dict[f'levels.{level}.pool.norm.bias'] = torch.tensor(
+            flax_dict[f'ConvPool_{level-1}']['LayerNorm_0']['bias'])
 
     # Final norm
     state_dict[f'norm.weight'] = torch.tensor(flax_dict['LayerNorm_0']['scale'])
@@ -105,5 +105,5 @@ def convert_nest(checkpoint_path, arch):
 
 if __name__ == '__main__':
     variant = sys.argv[1] # base, small, or tiny
-    state_dict = convert_nest(f'../nested-transformer/checkpoints/nest-{variant[0]}_imagenet', f'nest_{variant}')
-    torch.save(state_dict, f'/home/alexander/.cache/torch/hub/checkpoints/jx_nest_{variant}.pth')
\ No newline at end of file
+    state_dict = convert_nest(f'./nest-{variant[0]}_imagenet', f'nest_{variant}')
+    torch.save(state_dict, f'./jx_nest_{variant}.pth')
\ No newline at end of file
diff --git a/timm/models/nest.py b/timm/models/nest.py
index d2dc359a..601b6104 100644
--- a/timm/models/nest.py
+++ b/timm/models/nest.py
@@ -16,24 +16,19 @@ Copyright 2021 Alexander Soare
 """
 import collections.abc
-from functools import partial
-import math
 import logging
+import math
+from functools import partial
 
-import numpy as np
 import torch
-from torch import nn
 import torch.nn.functional as F
+from torch import nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import build_model_with_cfg, named_apply
 from .layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_
-from .layers.helpers import to_ntuple
-from .layers.create_conv2d import create_conv2d
-from .layers.pool2d_same import create_pool2d
-from .vision_transformer import Block
+from .layers import create_conv2d, create_pool2d, to_ntuple
 from .registry import register_model
-from .helpers import build_model_with_cfg, named_apply
-from .vision_transformer import resize_pos_embed
 
 _logger = logging.getLogger(__name__)
 
@@ -54,9 +49,12 @@ default_cfgs = {
     'nest_base': _cfg(),
     'nest_small': _cfg(),
     'nest_tiny': _cfg(),
-    'jx_nest_base': _cfg(url='https://www.todo-this-is-a-placeholder.com/jx_nest_base.pth'),  # TODO
-    'jx_nest_small': _cfg(url='https://www.todo-this-is-a-placeholder.com/jx_nest_small.pth'),  # TODO
-    'jx_nest_tiny': _cfg(url='https://www.todo-this-is-a-placeholder.com/jx_nest_tiny.pth'),  # TODO
+    'jx_nest_base': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_base-8bc41011.pth'),
+    'jx_nest_small': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_small-422eaded.pth'),
+    'jx_nest_tiny': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_tiny-e3428fb9.pth'),
 }
 
 
@@ -93,10 +91,10 @@ class Attention(nn.Module):
         x = (attn @ v).permute(0, 2, 3, 4, 1).reshape(B, T, N, C)
         x = self.proj(x)
         x = self.proj_drop(x)
-        return x # (B, T, N, C)
+        return x  # (B, T, N, C)
 
 
-class TransformerLayer(Block):
+class TransformerLayer(nn.Module):
     """
     This is much like `.vision_transformer.Block` but:
         - Called TransformerLayer here to allow for "block" as defined in the paper ("non-overlapping image blocks")
@@ -104,8 +102,7 @@ class TransformerLayer(Block):
     """
     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0.,
                  act_layer=nn.GELU, norm_layer=nn.LayerNorm):
-        super().__init__(dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0.,
-                         act_layer=nn.GELU, norm_layer=nn.LayerNorm)
+        super().__init__()
         self.norm1 = norm_layer(dim)
         self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@@ -120,7 +117,7 @@ class TransformerLayer(Block):
         return x
 
 
-class BlockAggregation(nn.Module):
+class ConvPool(nn.Module):
    def __init__(self, in_channels, out_channels, norm_layer, pad_type=''):
         super().__init__()
         self.conv = create_conv2d(in_channels, out_channels, kernel_size=3, padding=pad_type, bias=True)
@@ -137,7 +134,7 @@ class BlockAggregation(nn.Module):
         # Layer norm done over channel dim only
         x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         x = self.pool(x)
-        return x # (B, C, H//2, W//2)
+        return x  # (B, C, H//2, W//2)
 
 
 def blockify(x, block_size: int):
@@ -152,9 +149,8 @@ def blockify(x, block_size: int):
     grid_height = H // block_size
     grid_width = W // block_size
     x = x.reshape(B, grid_height, block_size, grid_width, block_size, C)
-    x = x.permute(0, 1, 3, 2, 4, 5)
-    x = x.reshape(B, grid_height * grid_width, -1, C)
-    return x # (B, T, N, C)
+    x = x.transpose(2, 3).reshape(B, grid_height * grid_width, -1, C)
+    return x  # (B, T, N, C)
 
 
 def deblockify(x, block_size: int):
@@ -163,23 +159,30 @@ def deblockify(x, block_size: int):
         x (Tensor): with shape (B, T, N, C) where T is number of blocks and N is sequence size per block
         block_size (int): edge length of a single square block in units of desired H, W
     """
-    B, T, _, C= x.shape
+    B, T, _, C = x.shape
     grid_size = int(math.sqrt(T))
-    x = x.reshape(B, grid_size, grid_size, block_size, block_size, C)
-    x = x.permute(0, 1, 3, 2, 4, 5)
     height = width = grid_size * block_size
-    x = x.reshape(B, height, width, C)
-    return x # (B, H, W, C)
-
+    x = x.reshape(B, grid_size, grid_size, block_size, block_size, C)
+    x = x.transpose(2, 3).reshape(B, height, width, C)
+    return x  # (B, H, W, C)
+
 
 class NestLevel(nn.Module):
     """ Single hierarchical level of a Nested Transformer
     """
-    def __init__(self, num_blocks, block_size, seq_length, num_heads, depth, embed_dim, mlp_ratio=4., qkv_bias=True,
-                 drop_rate=0., attn_drop_rate=0., drop_path_rates=[], norm_layer=None, act_layer=None):
+    def __init__(
+            self, num_blocks, block_size, seq_length, num_heads, depth, embed_dim, prev_embed_dim=None,
+            mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rates=[],
+            norm_layer=None, act_layer=None, pad_type=''):
         super().__init__()
         self.block_size = block_size
         self.pos_embed = nn.Parameter(torch.zeros(1, num_blocks, seq_length, embed_dim))
+
+        if prev_embed_dim is not None:
+            self.pool = ConvPool(prev_embed_dim, embed_dim, norm_layer=norm_layer, pad_type=pad_type)
+        else:
+            self.pool = nn.Identity()
+
         # Transformer encoder
         if len(drop_path_rates):
             assert len(drop_path_rates) == depth, 'Must provide as many drop path rates as there are transformer layers'
@@ -194,15 +197,14 @@ class NestLevel(nn.Module):
         """
         expects x as (B, C, H, W)
         """
-        # Switch to channels last for transformer
-        x = x.permute(0, 2, 3, 1) # (B, H', W', C)
-        x = blockify(x, self.block_size) # (B, T, N, C')
+        x = self.pool(x)
+        x = x.permute(0, 2, 3, 1)  # (B, H', W', C), switch to channels last for transformer
+        x = blockify(x, self.block_size)  # (B, T, N, C')
         x = x + self.pos_embed
-        x = self.transformer_encoder(x) # (B, T, N, C')
-        x = deblockify(x, self.block_size) # (B, H', W', C')
+        x = self.transformer_encoder(x)  # (B, T, N, C')
+        x = deblockify(x, self.block_size)  # (B, H', W', C')
         # Channel-first for block aggregation, and generally to replicate convnet feature map at each stage
-        x = x.permute(0, 3, 1, 2) # (B, C, H', W')
-        return x
+        return x.permute(0, 3, 1, 2)  # (B, C, H', W')
 
 
 class Nest(nn.Module):
@@ -213,9 +215,9 @@ class Nest(nn.Module):
     """
 
     def __init__(self, img_size=224, in_chans=3, patch_size=4, num_levels=3, embed_dims=(128, 256, 512),
-                 num_heads=(4, 8, 16), depths=(2, 2, 20), num_classes=1000, mlp_ratio=4., qkv_bias=True, pad_type='',
-                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.5, norm_layer=None, act_layer=None, weight_init='',
-                 global_pool='avg'):
+                 num_heads=(4, 8, 16), depths=(2, 2, 20), num_classes=1000, mlp_ratio=4., qkv_bias=True,
+                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.5, norm_layer=None, act_layer=None,
+                 pad_type='', weight_init='', global_pool='avg'):
         """
         Args:
             img_size (int, tuple): input image size
@@ -233,6 +235,7 @@ class Nest(nn.Module):
             drop_path_rate (float): stochastic depth rate
             norm_layer: (nn.Module): normalization layer for transformer layers
             act_layer: (nn.Module): activation layer in MLP of transformer layers
+            pad_type: (str): type of padding to use; '' for PyTorch symmetric, 'same' for TF SAME
             weight_init: (str): weight init scheme
             global_pool: (str): type of pooling operation to apply to final feature map
         """
@@ -254,6 +257,7 @@ class Nest(nn.Module):
         depths = to_ntuple(num_levels)(depths)
         self.num_classes = num_classes
         self.num_features = embed_dims[-1]
+        self.feature_info = []
         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
         act_layer = act_layer or nn.GELU
         self.drop_rate = drop_rate
@@ -265,60 +269,54 @@ class Nest(nn.Module):
         self.patch_size = patch_size
 
         # Number of blocks at each level
-        self.num_blocks = 4**(np.arange(num_levels)[::-1])
-        assert (img_size // patch_size) % np.sqrt(self.num_blocks[0]) == 0, \
-            'First level blocks don\'t fit evenly. Check `img_size`, `patch_size`, and `num_levels`'
+        self.num_blocks = (4 ** torch.arange(num_levels)).flip(0).tolist()
+        assert (img_size // patch_size) % math.sqrt(self.num_blocks[0]) == 0, \
+            'First level blocks don\'t fit evenly. Check `img_size`, `patch_size`, and `num_levels`'
 
         # Block edge size in units of patches
         # Hint: (img_size // patch_size) gives number of patches along edge of image. sqrt(self.num_blocks[0]) is the
         #  number of blocks along edge of image
-        self.block_size = int((img_size // patch_size) // np.sqrt(self.num_blocks[0]))
+        self.block_size = int((img_size // patch_size) // math.sqrt(self.num_blocks[0]))
 
         # Patch embedding
         self.patch_embed = PatchEmbed(
-            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0])
+            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0], flatten=False)
         self.num_patches = self.patch_embed.num_patches
         self.seq_length = self.num_patches // self.num_blocks[0]
 
         # Build up each hierarchical level
-        self.levels = nn.ModuleList([])
-        self.block_aggs = nn.ModuleList([])
-        drop_path_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
-        for lix in range(self.num_levels):
-            dpr = drop_path_rates[sum(depths[:lix]):sum(depths[:lix+1])]
-            self.levels.append(NestLevel(
-                self.num_blocks[lix], self.block_size, self.seq_length, num_heads[lix], depths[lix],
-                embed_dims[lix], mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, dpr, norm_layer,
-                act_layer))
-            if lix < self.num_levels - 1:
-                self.block_aggs.append(BlockAggregation(
-                    embed_dims[lix], embed_dims[lix+1], norm_layer, pad_type=pad_type))
-            else:
-                # Required for zipped iteration over levels and ls_block_agg together
-                self.block_aggs.append(nn.Identity())
+        levels = []
+        dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
+        prev_dim = None
+        curr_stride = 4
+        for i in range(len(self.num_blocks)):
+            dim = embed_dims[i]
+            levels.append(NestLevel(
+                self.num_blocks[i], self.block_size, self.seq_length, num_heads[i], depths[i], dim, prev_dim,
+                mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, dp_rates[i], norm_layer, act_layer, pad_type=pad_type))
+            self.feature_info += [dict(num_chs=dim, reduction=curr_stride, module=f'levels.{i}')]
+            prev_dim = dim
+            curr_stride *= 2
+        self.levels = nn.Sequential(*levels)
 
         # Final normalization layer
         self.norm = norm_layer(embed_dims[-1])
 
         # Classifier
-        self.global_pool, self.head = create_classifier(
-            self.num_features, self.num_classes, pool_type=global_pool)
+        self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
 
         self.init_weights(weight_init)
 
     def init_weights(self, mode=''):
-        assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
+        assert mode in ('nlhb', '')
         head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
         for level in self.levels:
             trunc_normal_(level.pos_embed, std=.02, a=-2, b=2)
-        if mode.startswith('jax'):
-            named_apply(partial(_init_nest_weights, head_bias=head_bias, jax_impl=True), self)
-        else:
-            named_apply(_init_nest_weights, self)
+        named_apply(partial(_init_nest_weights, head_bias=head_bias), self)
 
     @torch.jit.ignore
     def no_weight_decay(self):
-        return {'pos_embed'}
+        return {f'levels.{i}.pos_embed' for i in range(len(self.levels))}
 
     def get_classifier(self):
         return self.head
@@ -333,13 +331,8 @@ class Nest(nn.Module):
         """
         B, _, H, W = x.shape
         x = self.patch_embed(x)
-        x = x.reshape(B, H//self.patch_size, W//self.patch_size, -1) # (B, H', W', C')
-        x = x.permute(0, 3, 1, 2)
-        # NOTE: TorchScript won't let us subscript module lists with integer variables, so we iterate instead
-        for level, block_agg in zip(self.levels, self.block_aggs):
-            x = level(x)
-            x = block_agg(x)
-        # Layer norm done over channel dim only
+        x = self.levels(x)
+        # Layer norm done over channel dim only (to NHWC and back)
         x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         return x
 
@@ -353,22 +346,19 @@ class Nest(nn.Module):
         return self.head(x)
 
 
-def _init_nest_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False):
+def _init_nest_weights(module: nn.Module, name: str = '', head_bias: float = 0.):
     """ NesT weight initialization
     Can replicate Jax implementation. Otherwise follows vision_transformer.py
     """
     if isinstance(module, nn.Linear):
         if name.startswith('head'):
-            if jax_impl:
-                trunc_normal_(module.weight, std=.02, a=-2, b=2)
-            else:
-                nn.init.zeros_(module.weight)
+            trunc_normal_(module.weight, std=.02, a=-2, b=2)
             nn.init.constant_(module.bias, head_bias)
         else:
             trunc_normal_(module.weight, std=.02, a=-2, b=2)
             if module.bias is not None:
-                nn.init.zeros_(module.bias)
-    elif jax_impl and isinstance(module, nn.Conv2d):
+                nn.init.zeros_(module.bias)
+    elif isinstance(module, nn.Conv2d):
         trunc_normal_(module.weight, std=.02, a=-2, b=2)
         if module.bias is not None:
             nn.init.zeros_(module.bias)
@@ -404,13 +394,11 @@ def checkpoint_filter_fn(state_dict, model):
 
 
 def _create_nest(variant, pretrained=False, default_cfg=None, **kwargs):
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for Vision Transformer models.')
-
     default_cfg = default_cfg or default_cfgs[variant]
     model = build_model_with_cfg(
         Nest, variant, pretrained,
         default_cfg=default_cfg,
+        feature_cfg=dict(out_indices=(0, 1, 2), flatten_sequential=True),
         pretrained_filter_fn=checkpoint_filter_fn,
         **kwargs)
 
@@ -422,7 +410,7 @@ def nest_base(pretrained=False, **kwargs):
     """ Nest-B @ 224x224
     """
     model_kwargs = dict(
-        embed_dims=(128, 256, 512), num_heads=(4, 8, 16), depths=(2, 2, 20), drop_path_rate=0.5, **kwargs)
+        embed_dims=(128, 256, 512), num_heads=(4, 8, 16), depths=(2, 2, 20), **kwargs)
     model = _create_nest('nest_base', pretrained=pretrained, **model_kwargs)
     return model
 
@@ -431,8 +419,7 @@ def nest_base(pretrained=False, **kwargs):
 def nest_small(pretrained=False, **kwargs):
     """ Nest-S @ 224x224
     """
-    model_kwargs = dict(
-        embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 20), drop_path_rate=0.3, **kwargs)
+    model_kwargs = dict(embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 20), **kwargs)
     model = _create_nest('nest_small', pretrained=pretrained, **model_kwargs)
     return model
 
@@ -441,8 +428,7 @@ def nest_small(pretrained=False, **kwargs):
 def nest_tiny(pretrained=False, **kwargs):
     """ Nest-T @ 224x224
     """
-    model_kwargs = dict(
-        embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 8), drop_path_rate=0.2, **kwargs)
+    model_kwargs = dict(embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 8), **kwargs)
     model = _create_nest('nest_tiny', pretrained=pretrained, **model_kwargs)
     return model
 
@@ -452,9 +438,7 @@ def jx_nest_base(pretrained=False, **kwargs):
     """ Nest-B @ 224x224, Pretrained weights converted from official Jax impl.
     """
     kwargs['pad_type'] = 'same'
-    kwargs['weight_init'] = 'jax'
-    model_kwargs = dict(
-        embed_dims=(128, 256, 512), num_heads=(4, 8, 16), depths=(2, 2, 20), drop_path_rate=0.5, **kwargs)
+    model_kwargs = dict(embed_dims=(128, 256, 512), num_heads=(4, 8, 16), depths=(2, 2, 20), **kwargs)
     model = _create_nest('jx_nest_base', pretrained=pretrained, **model_kwargs)
     return model
 
@@ -464,9 +448,7 @@ def jx_nest_small(pretrained=False, **kwargs):
     """ Nest-S @ 224x224, Pretrained weights converted from official Jax impl.
     """
     kwargs['pad_type'] = 'same'
-    kwargs['weight_init'] = 'jax'
-    model_kwargs = dict(
-        embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 20), drop_path_rate=0.3, **kwargs)
+    model_kwargs = dict(embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 20), **kwargs)
     model = _create_nest('jx_nest_small', pretrained=pretrained, **model_kwargs)
     return model
 
@@ -476,8 +458,6 @@ def jx_nest_tiny(pretrained=False, **kwargs):
     """ Nest-T @ 224x224, Pretrained weights converted from official Jax impl.
     """
     kwargs['pad_type'] = 'same'
-    kwargs['weight_init'] = 'jax'
-    model_kwargs = dict(
-        embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 8), drop_path_rate=0.2, **kwargs)
+    model_kwargs = dict(embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 8), **kwargs)
    model = _create_nest('jx_nest_tiny', pretrained=pretrained, **model_kwargs)
-    return model
\ No newline at end of file
+    return model
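
Note on the reshaping helpers: the patched `blockify`/`deblockify` fold the old `permute(0, 1, 3, 2, 4, 5)` + `reshape` pair into a single `transpose(2, 3)` followed by `reshape`, and they remain exact inverses of each other. A standalone round-trip check (a minimal sketch that copies the patched helpers rather than importing timm):

```python
import math

import torch


def blockify(x, block_size: int):
    """Image to blocks: (B, H, W, C) -> (B, T, N, C), T blocks of N = block_size**2 tokens."""
    B, H, W, C = x.shape
    grid_height = H // block_size
    grid_width = W // block_size
    x = x.reshape(B, grid_height, block_size, grid_width, block_size, C)
    return x.transpose(2, 3).reshape(B, grid_height * grid_width, -1, C)


def deblockify(x, block_size: int):
    """Blocks to image: (B, T, N, C) -> (B, H, W, C), inverse of blockify."""
    B, T, _, C = x.shape
    grid_size = int(math.sqrt(T))
    height = width = grid_size * block_size
    x = x.reshape(B, grid_size, grid_size, block_size, block_size, C)
    return x.transpose(2, 3).reshape(B, height, width, C)


x = torch.randn(2, 56, 56, 128)   # channels-last, as used inside NestLevel
blocks = blockify(x, 14)          # (2, 16, 196, 128): a 4x4 grid of 14x14 blocks
assert blocks.shape == (2, 16, 196, 128)
assert torch.equal(deblockify(blocks, 14), x)
```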
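With aggregation moved inside `NestLevel` (each level is now its own `pool` + transformer encoder) and `self.feature_info` populated, the levels form a flat `nn.Sequential` and NesT works with timm's `features_only` wrapper instead of raising a RuntimeError. A usage sketch (output shapes assume the default 224x224 configs and `nest_tiny` as registered above):

```python
import timm
import torch

# features_only wraps the model per feature_cfg, here out_indices=(0, 1, 2)
model = timm.create_model('nest_tiny', features_only=True)
print(model.feature_info.channels())   # [96, 192, 384]
print(model.feature_info.reduction())  # [4, 8, 16]

features = model(torch.randn(1, 3, 224, 224))
for f in features:
    print(f.shape)
# torch.Size([1, 96, 56, 56])
# torch.Size([1, 192, 28, 28])
# torch.Size([1, 384, 14, 14])
```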
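Because the levels are registered as `self.levels = nn.Sequential(*levels)`, the positional embedding parameters are named `levels.{i}.pos_embed`, and `no_weight_decay()` must return exactly those names for the optimizer factory's skip list to match them. A quick consistency check (sketch):

```python
import timm

model = timm.create_model('nest_tiny')
pos_embed_names = {n for n, _ in model.named_parameters() if 'pos_embed' in n}
print(sorted(pos_embed_names))  # ['levels.0.pos_embed', 'levels.1.pos_embed', 'levels.2.pos_embed']
assert model.no_weight_decay() == pos_embed_names
```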