wip checkpoint with some feature extraction work

pull/731/head
Alexander Soare 3 years ago
parent 23bb72ce5e
commit b11d949a06

@@ -5,7 +5,8 @@ A PyTorch implementation of Aggregating Nested Transformers as described in:
'Aggregating Nested Transformers'
- https://arxiv.org/abs/2105.12723
The official Jax code is released and available at https://github.com/google-research/nested-transformer
The official Jax code is released and available at https://github.com/google-research/nested-transformer. The weights
have been converted with convert/convert_nest_flax.py
Acknowledgments:
* The paper authors for sharing their research, code, and model weights
@@ -37,7 +38,6 @@ from .vision_transformer import resize_pos_embed
_logger = logging.getLogger(__name__)
# TODO check first_conv. everything else has been checked
def _cfg(url='', **kwargs):
return {
'url': url,
@@ -60,7 +60,6 @@ default_cfgs = {
}
# TODO - Leave note for Ross - Maybe we can generalize Attention to this and put it in layers
class Attention(nn.Module):
"""
This is much like `.vision_transformer.Attention` but uses *localised* self attention by accepting an input with
@@ -102,7 +101,6 @@ class TransformerLayer(Block):
This is much like `.vision_transformer.Block` but:
- Called TransformerLayer here to allow for "block" as defined in the paper ("non-overlapping image blocks")
- Uses modified Attention layer that handles the "block" dimension
TODO somehow reuse the code instead of rewriting it...
"""
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
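# Illustrative sketch (an assumption, not from the diff): per the docstring above, TransformerLayer
# keeps the extra "block" dimension, mapping (B, T, N, C) -> (B, T, N, C) rather than the usual
# (B, N, C) of .vision_transformer.Block. With this module imported, roughly:
#
#     layer = TransformerLayer(dim=128, num_heads=4)
#     out = layer(torch.zeros(2, 16, 196, 128))   # B=2 images, T=16 blocks, N=196 tokens, C=128
#     assert out.shape == (2, 16, 196, 128)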
@@ -172,6 +170,39 @@ def deblockify(x, block_size: int):
height = width = grid_size * block_size
x = x.reshape(B, height, width, C)
return x # (B, H, W, C)
class NestLevel(nn.Module):
""" Single hierarchical level of a Nested Transformer
"""
def __init__(self, num_blocks, block_size, seq_length, num_heads, depth, embed_dim, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rates=[], norm_layer=None, act_layer=None):
super().__init__()
self.block_size = block_size
self.pos_embed = nn.Parameter(torch.zeros(1, num_blocks, seq_length, embed_dim))
# Transformer encoder
if len(drop_path_rates):
assert len(drop_path_rates) == depth, 'Must provide as many drop path rates as there are transformer layers'
self.transformer_encoder = nn.Sequential(*[
TransformerLayer(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rates[i],
norm_layer=norm_layer, act_layer=act_layer)
for i in range(depth)])
def forward(self, x):
"""
expects x as (B, C, H, W)
"""
# Switch to channels last for transformer
x = x.permute(0, 2, 3, 1) # (B, H', W', C)
x = blockify(x, self.block_size) # (B, T, N, C')
x = x + self.pos_embed
x = self.transformer_encoder(x) # (B, T, N, C')
x = deblockify(x, self.block_size) # (B, H', W', C')
# Channel-first for block aggregation, and generally to replicate convnet feature map at each stage
x = x.permute(0, 3, 1, 2) # (B, C, H', W')
return x
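# Illustrative sketch (not from the diff): NestLevel relies on blockify/deblockify being exact
# inverses. With, for example, a 56x56 channels-last map (img_size=224, patch_size=4) and
# block_size=14, blockify yields T=16 blocks of N=196 tokens each, and deblockify undoes it:
#
#     x = torch.rand(2, 56, 56, 128)                    # (B, H, W, C)
#     b = blockify(x, block_size=14)                    # (2, 16, 196, 128) == (B, T, N, C)
#     assert torch.equal(deblockify(b, block_size=14), x)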
class Nest(nn.Module):
@@ -182,10 +213,9 @@ class Nest(nn.Module):
"""
def __init__(self, img_size=224, in_chans=3, patch_size=4, num_levels=3, embed_dims=(128, 256, 512),
num_heads=(4, 8, 16), depths=(2, 2, 20), num_classes=1000, mlp_ratio=4.,
qkv_bias=True, pad_type='',
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.5, norm_layer=None,
act_layer=None, weight_init='', global_pool='avg'):
num_heads=(4, 8, 16), depths=(2, 2, 20), num_classes=1000, mlp_ratio=4., qkv_bias=True, pad_type='',
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.5, norm_layer=None, act_layer=None, weight_init='',
global_pool='avg'):
"""
Args:
img_size (int, tuple): input image size
@@ -203,7 +233,7 @@ class Nest(nn.Module):
drop_path_rate (float): stochastic depth rate
norm_layer: (nn.Module): normalization layer for transformer layers
act_layer: (nn.Module): activation layer in MLP of transformer layers
weight_init: (str): weight init scheme TODO check
weight_init: (str): weight init scheme
global_pool: (str): type of pooling operation to apply to final feature map
Notes:
@@ -247,45 +277,33 @@ class Nest(nn.Module):
# Patch embedding
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0])
self.feature_info = [dict(num_chs=embed_dims[0], reduction=patch_size, module='patch_embed')]
self.num_patches = self.patch_embed.num_patches
self.seq_length = self.num_patches // self.num_blocks[0]
# Build up each hierarchical level
self.ls_pos_embed = []
self.ls_transformer_encoder = nn.ModuleList([])
self.ls_block_aggregation = nn.ModuleList([])
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # drop path rate
self.feature_info = []
for level in range(self.num_levels):
# Positional embedding
# NOTE: Can't use ParameterList for positional embedding as it can't be enumerated with TorchScript
pos_embed = nn.Parameter(
torch.zeros(1, self.num_blocks[level], self.num_patches // self.num_blocks[0], embed_dims[level]))
self.register_parameter(f'pos_embed_{level}', pos_embed)
self.ls_pos_embed.append(pos_embed)
# Transformer encoder
self.ls_transformer_encoder.append(nn.Sequential(*[
TransformerLayer(
dim=embed_dims[level], num_heads=num_heads[level], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:level]) + i],
norm_layer=norm_layer, act_layer=act_layer)
for i in range(depths[level])]))
self.feature_info.append(dict(
num_chs=embed_dims[level], reduction=2,
module=f'ls_transformer_encoder.{level}.{depths[level]-1}.mlp.fc2'))
# Block aggregation (not required for last level)
if level < self.num_levels - 1:
self.ls_block_aggregation.append(
BlockAggregation(embed_dims[level], embed_dims[level+1], norm_layer, pad_type=pad_type))
self.levels = nn.ModuleList([])
self.block_aggs = nn.ModuleList([])
drop_path_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
for lix in range(self.num_levels):
dpr = drop_path_rates[sum(depths[:lix]):sum(depths[:lix+1])]
self.levels.append(NestLevel(
self.num_blocks[lix], self.block_size, self.seq_length, num_heads[lix], depths[lix],
embed_dims[lix], mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, dpr, norm_layer,
act_layer))
self.feature_info.append(
dict(num_chs=embed_dims[lix], reduction=self.feature_info[-1]['reduction']*2, module=f'levels.{lix}'))
if lix < self.num_levels - 1:
self.block_aggs.append(BlockAggregation(
embed_dims[lix], embed_dims[lix+1], norm_layer, pad_type=pad_type))
else:
# NOTE: Required for enumeration over all level components at once
self.ls_block_aggregation.append(nn.Identity())
self.ls_pos_embed = tuple(self.ls_pos_embed) # static length required for torchscript
# Required for zipped iteration over levels and block_aggs together
self.block_aggs.append(nn.Identity())
# Final normalization layer
self.norm = norm_layer(embed_dims[-1])
self.feature_info.append(
dict(num_chs=embed_dims[lix], reduction=self.feature_info[-1]['reduction'], module='norm'))
# Classifier
self.global_pool, self.head = create_classifier(
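# Illustrative sketch (not from the diff): the loop above slices one linearly spaced
# stochastic-depth schedule into per-level chunks, so each level's TransformerLayers get a
# contiguous run of rates. For the defaults depths=(2, 2, 20), drop_path_rate=0.5:
#
#     import torch
#     depths, drop_path_rate = (2, 2, 20), 0.5
#     rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
#     per_level = [rates[sum(depths[:i]):sum(depths[:i + 1])] for i in range(len(depths))]
#     # len(per_level[i]) == depths[i]; the last level ramps up to drop_path_rate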
@@ -296,8 +314,8 @@ class Nest(nn.Module):
def init_weights(self, mode=''):
assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
for pos_embed in self.ls_pos_embed:
trunc_normal_(pos_embed, std=.02, a=-2, b=2)
for level in self.levels:
trunc_normal_(level.pos_embed, std=.02, a=-2, b=2)
if mode.startswith('jax'):
named_apply(partial(_init_nest_weights, head_bias=head_bias, jax_impl=True), self)
else:
@@ -319,22 +337,13 @@ class Nest(nn.Module):
""" x shape (B, C, H, W)
"""
B, _, H, W = x.shape
x = self.patch_embed(x) # (B, N, C)
x = self.patch_embed(x)
x = x.reshape(B, H//self.patch_size, W//self.patch_size, -1) # (B, H', W', C')
# NOTE: TorchScript wants enumeration rather than subscripting of ModuleList
for level, (pos_embed, transformer, block_agg) in enumerate(
zip(self.ls_pos_embed, self.ls_transformer_encoder, self.ls_block_aggregation)):
if level > 0:
# Switch back to channels last for transformer
x = x.permute(0, 2, 3, 1) # (B, H', W', C)
x = blockify(x, self.block_size) # (B, T, N, C')
x = x + pos_embed
x = transformer(x) # (B, T, N, C')
x = deblockify(x, self.block_size) # (B, H', W', C')
# Channel-first for block aggregation, and generally to replicate convnet feature map at each stage
x = x.permute(0, 3, 1, 2) # (B, C, H', W')
if level < self.num_levels - 1:
x = block_agg(x) # (B, C', H'//2, W'//2)
x = x.permute(0, 3, 1, 2)
# NOTE: TorchScript won't let us subscript module lists with integer variables, so we iterate instead
for level, block_agg in zip(self.levels, self.block_aggs):
x = level(x)
x = block_agg(x)
# Layer norm done over channel dim only
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x
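# Illustrative sketch (not from the diff): LayerNorm normalizes over the trailing dimension, so
# permuting (B, C, H, W) -> (B, H, W, C) before the norm applies it over C only, per spatial
# position, and the second permute restores the convnet-style layout:
#
#     import torch; from torch import nn
#     norm = nn.LayerNorm(512)
#     x = torch.rand(2, 512, 14, 14)
#     y = norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)   # same shape as x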
@@ -404,11 +413,12 @@ def _create_nest(variant, pretrained=False, default_cfg=None, **kwargs):
# raise RuntimeError('features_only not implemented for Vision Transformer models.')
default_cfg = default_cfg or default_cfgs[variant]
model = build_model_with_cfg(
Nest, variant, pretrained,
default_cfg=default_cfg,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(
out_indices=tuple(range(kwargs.get('num_levels', 3) + 2)), feature_cls='hook', flatten_sequential=True),
**kwargs)
return model
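# Hedged usage sketch (an assumption about where this WIP is headed, not confirmed by the diff):
# with the hook-based feature_cfg above, the features_only path should expose one feature map per
# feature_info entry registered in Nest.__init__, roughly:
#
#     import timm, torch
#     model = timm.create_model('jx_nest_tiny', features_only=True, pretrained=False)
#     feats = model(torch.zeros(1, 3, 224, 224))
#     print([tuple(f.shape) for f in feats])
#     print(model.feature_info.channels())   # matches num_chs recorded in feature_info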
@@ -478,3 +488,10 @@ def jx_nest_tiny(pretrained=False, **kwargs):
embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 8), drop_path_rate=0.2, **kwargs)
model = _create_nest('jx_nest_tiny', pretrained=pretrained, **model_kwargs)
return model
if __name__ == '__main__':
model = jx_nest_base()
model = torch.jit.script(model)
inp = torch.zeros(8, 3, 224, 224)
print(model.forward_features(inp).shape)