@@ -5,7 +5,8 @@ A PyTorch implement of Aggregating Nested Transformers as described in:
'Aggregating Nested Transformers'
    - https://arxiv.org/abs/2105.12723

The official Jax code is released and available at https://github.com/google-research/nested-transformer
The official Jax code is released and available at https://github.com/google-research/nested-transformer. The weights
have been converted with convert/convert_nest_flax.py

Acknowledgments:
* The paper authors for sharing their research, code, and model weights
@@ -37,7 +38,6 @@ from .vision_transformer import resize_pos_embed
_logger = logging.getLogger(__name__)


# TODO check first_conv. everything else has been checked
def _cfg(url='', **kwargs):
    return {
        'url': url,
@@ -60,7 +60,6 @@ default_cfgs = {
}


# TODO - Leave note for Ross - Maybe we can generalize Attention to this and put it in layers
class Attention(nn.Module):
    """
    This is much like `.vision_transformer.Attention` but uses *localised* self attention by accepting an input with
@@ -102,7 +101,6 @@ class TransformerLayer(Block):
    This is much like `.vision_transformer.Block` but:
        - Called TransformerLayer here to allow for "block" as defined in the paper ("non-overlapping image blocks")
        - Uses modified Attention layer that handles the "block" dimension
    TODO somehow reuse the code instead of rewriting it...
    """
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
@@ -172,6 +170,39 @@ def deblockify(x, block_size: int):
    height = width = grid_size * block_size
    x = x.reshape(B, height, width, C)
    return x  # (B, H, W, C)
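
# Shape example (illustrative): with block_size=14, blockify maps a (B, 56, 56, C) feature map to (B, 16, 196, C),
# i.e. 16 non-overlapping 14x14 blocks of 196 tokens each, and deblockify inverts it back to (B, 56, 56, C).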


class NestLevel(nn.Module):
    """ Single hierarchical level of a Nested Transformer
    """
    def __init__(self, num_blocks, block_size, seq_length, num_heads, depth, embed_dim, mlp_ratio=4., qkv_bias=True,
                 drop_rate=0., attn_drop_rate=0., drop_path_rates=[], norm_layer=None, act_layer=None):
        super().__init__()
        self.block_size = block_size
        self.pos_embed = nn.Parameter(torch.zeros(1, num_blocks, seq_length, embed_dim))
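        # The embedding above is a learned table shared across the batch: one vector per (block, within-block
        # position) pair at this level, added to the blockified tokens in forward()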
        # Transformer encoder
        if len(drop_path_rates):
            assert len(drop_path_rates) == depth, 'Must provide as many drop path rates as there are transformer layers'
        self.transformer_encoder = nn.Sequential(*[
            TransformerLayer(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rates[i],
                norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)])

    def forward(self, x):
        """
        expects x as (B, C, H, W)
        """
        # Switch to channels last for transformer
        x = x.permute(0, 2, 3, 1)  # (B, H', W', C)
        x = blockify(x, self.block_size)  # (B, T, N, C')
        x = x + self.pos_embed
        x = self.transformer_encoder(x)  # (B, T, N, C')
        x = deblockify(x, self.block_size)  # (B, H', W', C')
        # Channel-first for block aggregation, and generally to replicate convnet feature map at each stage
        x = x.permute(0, 3, 1, 2)  # (B, C, H', W')
        return x
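
# Rough usage sketch (illustrative values only, matching the first level of the default configuration):
#   level = NestLevel(num_blocks=16, block_size=14, seq_length=196, num_heads=4, depth=2, embed_dim=128,
#                     drop_path_rates=[0., 0.], norm_layer=nn.LayerNorm, act_layer=nn.GELU)
#   level(torch.randn(2, 128, 56, 56)).shape  # -> torch.Size([2, 128, 56, 56])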


class Nest(nn.Module):
@@ -182,10 +213,9 @@ class Nest(nn.Module):
    """

    def __init__(self, img_size=224, in_chans=3, patch_size=4, num_levels=3, embed_dims=(128, 256, 512),
                 num_heads=(4, 8, 16), depths=(2, 2, 20), num_classes=1000, mlp_ratio=4.,
                 qkv_bias=True, pad_type='',
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.5, norm_layer=None,
                 act_layer=None, weight_init='', global_pool='avg'):
                 num_heads=(4, 8, 16), depths=(2, 2, 20), num_classes=1000, mlp_ratio=4., qkv_bias=True, pad_type='',
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.5, norm_layer=None, act_layer=None, weight_init='',
                 global_pool='avg'):
        """
        Args:
            img_size (int, tuple): input image size
@@ -203,7 +233,7 @@ class Nest(nn.Module):
            drop_path_rate (float): stochastic depth rate
            norm_layer: (nn.Module): normalization layer for transformer layers
            act_layer: (nn.Module): activation layer in MLP of transformer layers
            weight_init: (str): weight init scheme TODO check
            weight_init: (str): weight init scheme
            global_pool: (str): type of pooling operation to apply to final feature map

        Notes:
@@ -247,45 +277,33 @@ class Nest(nn.Module):
        # Patch embedding
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0])
        self.feature_info = [dict(num_chs=embed_dims[0], reduction=patch_size, module='patch_embed')]
        self.num_patches = self.patch_embed.num_patches
        self.seq_length = self.num_patches // self.num_blocks[0]
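        # Illustrative numbers for the default 224x224 config with patch_size=4, assuming num_blocks = (16, 4, 1):
        # 56*56 = 3136 patches split across 16 first-level blocks gives seq_length = 3136 // 16 = 196 tokens per block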

        # Build up each hierarchical level
        self.ls_pos_embed = []
        self.ls_transformer_encoder = nn.ModuleList([])
        self.ls_block_aggregation = nn.ModuleList([])
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # drop path rate
        self.feature_info = []
        for level in range(self.num_levels):
            # Positional embedding
            # NOTE: Can't use ParameterList for positional embedding as it can't be enumerated with TorchScript
            pos_embed = nn.Parameter(
                torch.zeros(1, self.num_blocks[level], self.num_patches // self.num_blocks[0], embed_dims[level]))
            self.register_parameter(f'pos_embed_{level}', pos_embed)
            self.ls_pos_embed.append(pos_embed)
            # Transformer encoder
            self.ls_transformer_encoder.append(nn.Sequential(*[
                TransformerLayer(
                    dim=embed_dims[level], num_heads=num_heads[level], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                    drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:level]) + i],
                    norm_layer=norm_layer, act_layer=act_layer)
                for i in range(depths[level])]))
            self.feature_info.append(dict(
                num_chs=embed_dims[level], reduction=2,
                module=f'ls_transformer_encoder.{level}.{depths[level]-1}.mlp.fc2'))
            # Block aggregation (not required for last level)
            if level < self.num_levels - 1:
                self.ls_block_aggregation.append(
                    BlockAggregation(embed_dims[level], embed_dims[level+1], norm_layer, pad_type=pad_type))
        self.levels = nn.ModuleList([])
        self.block_aggs = nn.ModuleList([])
        drop_path_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        for lix in range(self.num_levels):
            dpr = drop_path_rates[sum(depths[:lix]):sum(depths[:lix+1])]
            self.levels.append(NestLevel(
                self.num_blocks[lix], self.block_size, self.seq_length, num_heads[lix], depths[lix],
                embed_dims[lix], mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, dpr, norm_layer,
                act_layer))
            self.feature_info.append(
                dict(num_chs=embed_dims[lix], reduction=self.feature_info[-1]['reduction']*2, module=f'levels.{lix}'))
            if lix < self.num_levels - 1:
                self.block_aggs.append(BlockAggregation(
                    embed_dims[lix], embed_dims[lix+1], norm_layer, pad_type=pad_type))
            else:
                # NOTE: Required for enumeration over all level components at once
                self.ls_block_aggregation.append(nn.Identity())
        self.ls_pos_embed = tuple(self.ls_pos_embed)  # static length required for torchscript
                # Required for zipped iteration over levels and ls_block_agg together
                self.block_aggs.append(nn.Identity())

        # Final normalization layer
        self.norm = norm_layer(embed_dims[-1])
        self.feature_info.append(
            dict(num_chs=embed_dims[lix], reduction=self.feature_info[-1]['reduction'], module='norm'))
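        # Illustrative summary for the default embed_dims=(128, 256, 512) and patch_size=4: the levels-based path
        # above yields feature_info (num_chs, reduction) entries of (128, 4) patch_embed, (128, 8) levels.0,
        # (256, 16) levels.1, (512, 32) levels.2 and (512, 32) norm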

        # Classifier
        self.global_pool, self.head = create_classifier(
@@ -296,8 +314,8 @@ class Nest(nn.Module):
    def init_weights(self, mode=''):
        assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
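        # The 'nlhb' modes set the classifier bias to -log(num_classes) (about -6.91 for 1000 classes), so that
        # exp(head_bias) is roughly 1/num_classes at initialization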
        for pos_embed in self.ls_pos_embed:
            trunc_normal_(pos_embed, std=.02, a=-2, b=2)
        for level in self.levels:
            trunc_normal_(level.pos_embed, std=.02, a=-2, b=2)
        if mode.startswith('jax'):
            named_apply(partial(_init_nest_weights, head_bias=head_bias, jax_impl=True), self)
        else:
@@ -319,22 +337,13 @@ class Nest(nn.Module):
        """ x shape (B, C, H, W)
        """
        B, _, H, W = x.shape
        x = self.patch_embed(x)  # (B, N, C)
        x = self.patch_embed(x)
        x = x.reshape(B, H//self.patch_size, W//self.patch_size, -1)  # (B, H', W', C')
        # NOTE: TorchScript wants enumeration rather than subscripting of ModuleList
        for level, (pos_embed, transformer, block_agg) in enumerate(
                zip(self.ls_pos_embed, self.ls_transformer_encoder, self.ls_block_aggregation)):
            if level > 0:
                # Switch back to channels last for transformer
                x = x.permute(0, 2, 3, 1)  # (B, H', W', C)
            x = blockify(x, self.block_size)  # (B, T, N, C')
            x = x + pos_embed
            x = transformer(x)  # (B, T, N, C')
            x = deblockify(x, self.block_size)  # (B, H', W', C')
            # Channel-first for block aggregation, and generally to replicate convnet feature map at each stage
            x = x.permute(0, 3, 1, 2)  # (B, C, H', W')
            if level < self.num_levels - 1:
                x = block_agg(x)  # (B, C', H'//2, W'//2)
        x = x.permute(0, 3, 1, 2)
        # NOTE: TorchScript won't let us subscript module lists with integer variables, so we iterate instead
        for level, block_agg in zip(self.levels, self.block_aggs):
            x = level(x)
            x = block_agg(x)
        # Layer norm done over channel dim only
        x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return x
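        # Illustrative shapes for a 224x224 input with the default config: (B, 128, 56, 56) entering the first
        # level, (B, 256, 28, 28) after the first aggregation, (B, 512, 14, 14) after the second, and
        # (B, 512, 14, 14) out of the final norm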
@@ -404,11 +413,12 @@ def _create_nest(variant, pretrained=False, default_cfg=None, **kwargs):
        # raise RuntimeError('features_only not implemented for Vision Transformer models.')

    default_cfg = default_cfg or default_cfgs[variant]

    model = build_model_with_cfg(
        Nest, variant, pretrained,
        default_cfg=default_cfg,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(
            out_indices=tuple(range(kwargs.get('num_levels', 3) + 2)), feature_cls='hook', flatten_sequential=True),
        **kwargs)
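    # out_indices above spans the num_levels + 2 feature_info entries built in Nest.__init__
    # (patch_embed, one per level, and the final norm)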

    return model
@@ -478,3 +488,10 @@ def jx_nest_tiny(pretrained=False, **kwargs):
        embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 8), drop_path_rate=0.2, **kwargs)
    model = _create_nest('jx_nest_tiny', pretrained=pretrained, **model_kwargs)
    return model


if __name__ == '__main__':
    model = jx_nest_base()
    model = torch.jit.script(model)
    inp = torch.zeros(8, 3, 224, 224)
    print(model.forward_features(inp).shape)
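
    # Rough end-to-end sketch (illustrative; assumes the usual timm registry entry points for these models):
    #   import timm
    #   m = timm.create_model('jx_nest_base', pretrained=True).eval()
    #   m(torch.zeros(1, 3, 224, 224)).shape  # -> torch.Size([1, 1000])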