From b11d949a06f4c19be616aba25e49bf33d13a6b23 Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Sat, 3 Jul 2021 11:45:19 +0100
Subject: [PATCH] wip checkpoint with some feature extraction work

---
 timm/models/nest.py | 135 +++++++++++++++++++++++++-------------------
 1 file changed, 76 insertions(+), 59 deletions(-)

diff --git a/timm/models/nest.py b/timm/models/nest.py
index dcf664c6..0131b677 100644
--- a/timm/models/nest.py
+++ b/timm/models/nest.py
@@ -5,7 +5,8 @@
 A PyTorch implement of Aggregating Nested Transformers as described in:
 
 'Aggregating Nested Transformers' - https://arxiv.org/abs/2105.12723
-The official Jax code is released and available at https://github.com/google-research/nested-transformer
+The official Jax code is released and available at https://github.com/google-research/nested-transformer. The weights
+have been converted with convert/convert_nest_flax.py
 
 Acknowledgments:
 * The paper authors for sharing their research, code, and model weights
@@ -37,7 +38,6 @@ from .vision_transformer import resize_pos_embed
 _logger = logging.getLogger(__name__)
 
 
-# TODO check first_conv. everything else has been checked
 def _cfg(url='', **kwargs):
     return {
         'url': url,
@@ -60,7 +60,6 @@
 }
 
 
-# TODO - Leave note for Ross - Maybe we can generalize Attention to this and put it in layers
 class Attention(nn.Module):
     """
     This is much like `.vision_transformer.Attention` but uses *localised* self attention by accepting an input with
@@ -102,7 +101,6 @@ class TransformerLayer(Block):
     This is much like `.vision_transformer.Block` but:
         - Called TransformerLayer here to allow for "block" as defined in the paper ("non-overlapping image blocks")
         - Uses modified Attention layer that handles the "block" dimension
-        - TODO somehow reuse the code instead of rewriting it...
""" def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): @@ -172,6 +170,39 @@ def deblockify(x, block_size: int): height = width = grid_size * block_size x = x.reshape(B, height, width, C) return x # (B, H, W, C) + + +class NestLevel(nn.Module): + """ Single hierarchical level of a Nested Transformer + """ + def __init__(self, num_blocks, block_size, seq_length, num_heads, depth, embed_dim, mlp_ratio=4., qkv_bias=True, + drop_rate=0., attn_drop_rate=0., drop_path_rates=[], norm_layer=None, act_layer=None): + super().__init__() + self.block_size = block_size + self.pos_embed = nn.Parameter(torch.zeros(1, num_blocks, seq_length, embed_dim)) + # Transformer encoder + if len(drop_path_rates): + assert len(drop_path_rates) == depth, 'Must provide as many drop path rates as there are transformer layers' + self.transformer_encoder = nn.Sequential(*[ + TransformerLayer( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rates[i], + norm_layer=norm_layer, act_layer=act_layer) + for i in range(depth)]) + + def forward(self, x): + """ + expects x as (B, C, H, W) + """ + # Switch to channels last for transformer + x = x.permute(0, 2, 3, 1) # (B, H', W', C) + x = blockify(x, self.block_size) # (B, T, N, C') + x = x + self.pos_embed + x = self.transformer_encoder(x) # (B, T, N, C') + x = deblockify(x, self.block_size) # (B, H', W', C') + # Channel-first for block aggregation, and generally to replicate convnet feature map at each stage + x = x.permute(0, 3, 1, 2) # (B, C, H', W') + return x class Nest(nn.Module): @@ -182,10 +213,9 @@ class Nest(nn.Module): """ def __init__(self, img_size=224, in_chans=3, patch_size=4, num_levels=3, embed_dims=(128, 256, 512), - num_heads=(4, 8, 16), depths=(2, 2, 20), num_classes=1000, mlp_ratio=4., - qkv_bias=True, pad_type='', - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.5, norm_layer=None, - act_layer=None, weight_init='', global_pool='avg'): + num_heads=(4, 8, 16), depths=(2, 2, 20), num_classes=1000, mlp_ratio=4., qkv_bias=True, pad_type='', + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.5, norm_layer=None, act_layer=None, weight_init='', + global_pool='avg'): """ Args: img_size (int, tuple): input image size @@ -203,7 +233,7 @@ class Nest(nn.Module): drop_path_rate (float): stochastic depth rate norm_layer: (nn.Module): normalization layer for transformer layers act_layer: (nn.Module): activation layer in MLP of transformer layers - weight_init: (str): weight init scheme TODO check + weight_init: (str): weight init scheme global_pool: (str): type of pooling operation to apply to final feature map Notes: @@ -247,45 +277,33 @@ class Nest(nn.Module): # Patch embedding self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0]) + self.feature_info = [dict(num_chs=embed_dims[0], reduction=patch_size, module='patch_embed')] self.num_patches = self.patch_embed.num_patches + self.seq_length = self.num_patches // self.num_blocks[0] # Build up each hierarchical level - self.ls_pos_embed = [] - self.ls_transformer_encoder = nn.ModuleList([]) - self.ls_block_aggregation = nn.ModuleList([]) - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # drop path rate - self.feature_info = [] - for level in range(self.num_levels): - # Positional embedding - # NOTE: Can't use ParameterList for positional embedding 
-            pos_embed = nn.Parameter(
-                torch.zeros(1, self.num_blocks[level], self.num_patches // self.num_blocks[0], embed_dims[level]))
-            self.register_parameter(f'pos_embed_{level}', pos_embed)
-            self.ls_pos_embed.append(pos_embed)
-            # Transformer encoder
-            self.ls_transformer_encoder.append(nn.Sequential(*[
-                TransformerLayer(
-                    dim=embed_dims[level], num_heads=num_heads[level], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
-                    drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:level]) + i],
-                    norm_layer=norm_layer, act_layer=act_layer)
-                for i in range(depths[level])]))
-
-            self.feature_info.append(dict(
-                num_chs=embed_dims[level], reduction=2,
-                module=f'ls_transformer_encoder.{level}.{depths[level]-1}.mlp.fc2'))
-
-            # Block aggregation (not required for last level)
-            if level < self.num_levels - 1:
-                self.ls_block_aggregation.append(
-                    BlockAggregation(embed_dims[level], embed_dims[level+1], norm_layer, pad_type=pad_type))
+        self.levels = nn.ModuleList([])
+        self.block_aggs = nn.ModuleList([])
+        drop_path_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+        for lix in range(self.num_levels):
+            dpr = drop_path_rates[sum(depths[:lix]):sum(depths[:lix+1])]
+            self.levels.append(NestLevel(
+                self.num_blocks[lix], self.block_size, self.seq_length, num_heads[lix], depths[lix],
+                embed_dims[lix], mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, dpr, norm_layer,
+                act_layer))
+            self.feature_info.append(
+                dict(num_chs=embed_dims[lix], reduction=self.feature_info[-1]['reduction']*2, module=f'levels.{lix}'))
+            if lix < self.num_levels - 1:
+                self.block_aggs.append(BlockAggregation(
+                    embed_dims[lix], embed_dims[lix+1], norm_layer, pad_type=pad_type))
             else:
-                # NOTE: Required for enumeration over all level components at once
-                self.ls_block_aggregation.append(nn.Identity())
-        self.ls_pos_embed = tuple(self.ls_pos_embed)  # static length required for torchscript
+                # Required for zipped iteration over levels and block_aggs together
+                self.block_aggs.append(nn.Identity())
 
-        # Final normalization layer
         self.norm = norm_layer(embed_dims[-1])
+        self.feature_info.append(
+            dict(num_chs=embed_dims[-1], reduction=self.feature_info[-1]['reduction'], module='norm'))
 
         # Classifier
         self.global_pool, self.head = create_classifier(
@@ -296,8 +314,8 @@
     def init_weights(self, mode=''):
         assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
         head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
-        for pos_embed in self.ls_pos_embed:
-            trunc_normal_(pos_embed, std=.02, a=-2, b=2)
+        for level in self.levels:
+            trunc_normal_(level.pos_embed, std=.02, a=-2, b=2)
         if mode.startswith('jax'):
             named_apply(partial(_init_nest_weights, head_bias=head_bias, jax_impl=True), self)
         else:
@@ -319,22 +337,13 @@
         """ x shape (B, C, H, W)
        """
         B, _, H, W = x.shape
-        x = self.patch_embed(x)  # (B, N, C)
+        x = self.patch_embed(x)
         x = x.reshape(B, H//self.patch_size, W//self.patch_size, -1)  # (B, H', W', C')
-        # NOTE: TorchScript wants enumeration rather than subscripting of ModuleList
-        for level, (pos_embed, transformer, block_agg) in enumerate(
-                zip(self.ls_pos_embed, self.ls_transformer_encoder, self.ls_block_aggregation)):
-            if level > 0:
-                # Switch back to channels last for transformer
-                x = x.permute(0, 2, 3, 1)  # (B, H', W', C)
-            x = blockify(x, self.block_size)  # (B, T, N, C')
-            x = x + pos_embed
-            x = transformer(x)  # (B, T, N, C')
-            x = deblockify(x, self.block_size)  # (B, H', W', C')
-            # Channel-first for block aggregation, and generally to replicate convnet feature map at each stage
-            x = x.permute(0, 3, 1, 2)  # (B, C, H', W')
-            if level < self.num_levels - 1:
-                x = block_agg(x)  # (B, C', H'//2, W'//2)
+        x = x.permute(0, 3, 1, 2)
+        # NOTE: TorchScript won't let us subscript module lists with integer variables, so we iterate instead
+        for level, block_agg in zip(self.levels, self.block_aggs):
+            x = level(x)
+            x = block_agg(x)
         # Layer norm done over channel dim only
         x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         return x
@@ -404,11 +413,12 @@ def _create_nest(variant, pretrained=False, default_cfg=None, **kwargs):
 #        raise RuntimeError('features_only not implemented for Vision Transformer models.')
 
     default_cfg = default_cfg or default_cfgs[variant]
-
     model = build_model_with_cfg(
         Nest, variant, pretrained,
         default_cfg=default_cfg,
         pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(
+            out_indices=tuple(range(kwargs.get('num_levels', 3) + 2)), feature_cls='hook', flatten_sequential=True),
         **kwargs)
 
     return model
@@ -478,3 +488,10 @@ def jx_nest_tiny(pretrained=False, **kwargs):
         embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 8), drop_path_rate=0.2, **kwargs)
     model = _create_nest('jx_nest_tiny', pretrained=pretrained, **model_kwargs)
     return model
+
+
+if __name__ == '__main__':
+    model = jx_nest_base()
+    model = torch.jit.script(model)
+    inp = torch.zeros(8, 3, 224, 224)
+    print(model.forward_features(inp).shape)
\ No newline at end of file
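
---

Reviewer notes (standalone illustrations, not part of the patch):

1) The spatial bookkeeping in NestLevel.forward relies on blockify/deblockify being exact inverses: blockify cuts a
channels-last feature map into T non-overlapping blocks of N tokens each, and deblockify stitches them back. A minimal
sketch of that round trip (the two functions are re-stated here for illustration; the patch's own definitions live in
timm/models/nest.py, and the shape comments follow the ones in the diff):

    import torch

    def blockify(x, block_size: int):
        # (B, H, W, C) -> (B, T, N, C), where T = number of blocks and N = tokens per block
        B, H, W, C = x.shape
        grid_size = H // block_size
        x = x.reshape(B, grid_size, block_size, grid_size, block_size, C)
        x = x.transpose(2, 3).reshape(B, grid_size * grid_size, block_size * block_size, C)
        return x

    def deblockify(x, block_size: int):
        # (B, T, N, C) -> (B, H, W, C)
        B, T, _, C = x.shape
        grid_size = int(T ** 0.5)
        height = width = grid_size * block_size
        x = x.reshape(B, grid_size, grid_size, block_size, block_size, C)
        x = x.transpose(2, 3).reshape(B, height, width, C)
        return x

    x = torch.randn(2, 14, 14, 128)                # (B, H', W', C) after patch embed + reshape
    blocks = blockify(x, block_size=7)             # (2, 4, 49, 128): 4 blocks of 49 tokens each
    assert torch.equal(deblockify(blocks, 7), x)   # lossless round trip

Because the round trip is lossless, NestLevel can add a per-block positional embedding and run the transformer encoder
entirely on the (B, T, N, C) view, then hand a plain (B, C, H', W') feature map on to BlockAggregation.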
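2) On the nn.Identity padding of block_aggs: TorchScript can't index a ModuleList with a runtime integer, but it can
iterate one, so the patch pads block_aggs to the same length as levels and zips the two in forward_features. A toy
sketch of the same pattern (Pipeline is a hypothetical module, not from timm), which mirrors what the patch scripts
in its __main__ block, whereas subscripting self.levels[i] would fail to compile:

    import torch
    import torch.nn as nn

    class Pipeline(nn.Module):
        def __init__(self):
            super().__init__()
            self.levels = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])
            # Last stage needs no aggregation; nn.Identity keeps the two lists the same length
            self.block_aggs = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4), nn.Identity()])

        def forward(self, x):
            for level, agg in zip(self.levels, self.block_aggs):
                x = agg(level(x))
            return x

    scripted = torch.jit.script(Pipeline())
    print(scripted(torch.randn(2, 4)).shape)  # torch.Size([2, 4])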
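3) The feature_cfg added to _create_nest is the "feature extraction work" of the commit message: feature_info collects
one entry per stage (patch_embed, each level, final norm), and out_indices spans all num_levels + 2 of them via
hook-based extraction. Once this lands, the expected usage (untested here, since the patch is WIP) is the standard
timm features_only interface:

    import timm
    import torch

    model = timm.create_model('jx_nest_base', pretrained=False, features_only=True)
    print(model.feature_info.channels())       # e.g. [128, 128, 256, 512, 512] for nest_base
    features = model(torch.randn(1, 3, 224, 224))
    for i, feat in enumerate(features):
        print(i, tuple(feat.shape))            # one (B, C, H, W) map per feature_info entry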