wip to review

pull/731/head
Alexander Soare 3 years ago
parent b11d949a06
commit 7b8a0017f1

@@ -3,6 +3,8 @@ Convert weights from https://github.com/google-research/nested-transformer
 NOTE: You'll need https://github.com/google/CommonLoopUtils, not included in requirements.txt
 """
 
+import sys
 import numpy as np
 import torch
@@ -38,7 +40,7 @@ def convert_nest(checkpoint_path, arch):
     # Positional embeddings
     posemb_keys = [k for k in flax_dict.keys() if k.startswith('PositionEmbedding')]
     for i, k in enumerate(posemb_keys):
-        state_dict[f'pos_embed_{i}'] = torch.tensor(flax_dict[k]['pos_embedding'])
+        state_dict[f'levels.{i}.pos_embed'] = torch.tensor(flax_dict[k]['pos_embedding'])
 
     # Transformer encoders
     depths = arch_depths[arch]
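
Note on the rename above: position embeddings now live under each level's module path. A quick shape sanity check (hypothetical snippet; assumes `nest_base` is registered, that the model stores its levels in an iterable `self.levels`, and that `state_dict` is the dict built by `convert_nest`):

```python
import timm

model = timm.create_model('nest_base', pretrained=False)
for i, level in enumerate(model.levels):
    converted = state_dict[f'levels.{i}.pos_embed']
    assert converted.shape == level.pos_embed.shape, (i, converted.shape, level.pos_embed.shape)
```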
@@ -47,9 +49,9 @@ def convert_nest(checkpoint_path, arch):
             global_layer_ix = sum(depths[:level]) + layer
             # Norms
             for i in range(2):
-                state_dict[f'ls_transformer_encoder.{level}.{layer}.norm{i+1}.weight'] = torch.tensor(
+                state_dict[f'levels.{level}.transformer_encoder.{layer}.norm{i+1}.weight'] = torch.tensor(
                     flax_dict[f'EncoderNDBlock_{global_layer_ix}'][f'LayerNorm_{i}']['scale'])
-                state_dict[f'ls_transformer_encoder.{level}.{layer}.norm{i+1}.bias'] = torch.tensor(
+                state_dict[f'levels.{level}.transformer_encoder.{layer}.norm{i+1}.bias'] = torch.tensor(
                     flax_dict[f'EncoderNDBlock_{global_layer_ix}'][f'LayerNorm_{i}']['bias'])
             # Attention qkv
             w_q = flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MultiHeadAttention_0']['DenseGeneral_0']['kernel']
@@ -57,37 +59,37 @@ def convert_nest(checkpoint_path, arch):
             # Pay attention to dims here (maybe get pen and paper)
             w_kv = np.concatenate(np.split(w_kv, 2, -1), 1)
             w_qkv = np.concatenate([w_q, w_kv], 1)
-            state_dict[f'ls_transformer_encoder.{level}.{layer}.attn.qkv.weight'] = torch.tensor(w_qkv).flatten(1).permute(1,0)
+            state_dict[f'levels.{level}.transformer_encoder.{layer}.attn.qkv.weight'] = torch.tensor(w_qkv).flatten(1).permute(1,0)
             b_q = flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MultiHeadAttention_0']['DenseGeneral_0']['bias']
             b_kv = flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MultiHeadAttention_0']['DenseGeneral_1']['bias']
             # Pay attention to dims here (maybe get pen and paper)
             b_kv = np.concatenate(np.split(b_kv, 2, -1), 0)
             b_qkv = np.concatenate([b_q, b_kv], 0)
-            state_dict[f'ls_transformer_encoder.{level}.{layer}.attn.qkv.bias'] = torch.tensor(b_qkv).reshape(-1)
+            state_dict[f'levels.{level}.transformer_encoder.{layer}.attn.qkv.bias'] = torch.tensor(b_qkv).reshape(-1)
             # Attention proj
             w_proj = flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MultiHeadAttention_0']['proj_kernel']
             w_proj = torch.tensor(w_proj).permute(2, 1, 0).flatten(1)
-            state_dict[f'ls_transformer_encoder.{level}.{layer}.attn.proj.weight'] = w_proj
-            state_dict[f'ls_transformer_encoder.{level}.{layer}.attn.proj.bias'] = torch.tensor(
+            state_dict[f'levels.{level}.transformer_encoder.{layer}.attn.proj.weight'] = w_proj
+            state_dict[f'levels.{level}.transformer_encoder.{layer}.attn.proj.bias'] = torch.tensor(
                 flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MultiHeadAttention_0']['bias'])
             # MLP
             for i in range(2):
-                state_dict[f'ls_transformer_encoder.{level}.{layer}.mlp.fc{i+1}.weight'] = torch.tensor(
+                state_dict[f'levels.{level}.transformer_encoder.{layer}.mlp.fc{i+1}.weight'] = torch.tensor(
                     flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MlpBlock_0'][f'Dense_{i}']['kernel']).permute(1, 0)
-                state_dict[f'ls_transformer_encoder.{level}.{layer}.mlp.fc{i+1}.bias'] = torch.tensor(
+                state_dict[f'levels.{level}.transformer_encoder.{layer}.mlp.fc{i+1}.bias'] = torch.tensor(
                     flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MlpBlock_0'][f'Dense_{i}']['bias'])
 
     # Block aggregations
     for level in range(len(depths)-1):
         # Convs
-        state_dict[f'ls_block_aggregation.{level}.conv.weight'] = torch.tensor(
+        state_dict[f'block_aggs.{level}.conv.weight'] = torch.tensor(
             flax_dict[f'ConvPool_{level}']['Conv_0']['kernel']).permute(3, 2, 0, 1)
-        state_dict[f'ls_block_aggregation.{level}.conv.bias'] = torch.tensor(
+        state_dict[f'block_aggs.{level}.conv.bias'] = torch.tensor(
             flax_dict[f'ConvPool_{level}']['Conv_0']['bias'])
         # Norms
-        state_dict[f'ls_block_aggregation.{level}.norm.weight'] = torch.tensor(
+        state_dict[f'block_aggs.{level}.norm.weight'] = torch.tensor(
             flax_dict[f'ConvPool_{level}']['LayerNorm_0']['scale'])
-        state_dict[f'ls_block_aggregation.{level}.norm.bias'] = torch.tensor(
+        state_dict[f'block_aggs.{level}.norm.bias'] = torch.tensor(
             flax_dict[f'ConvPool_{level}']['LayerNorm_0']['bias'])
 
     # Final norm
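
The two "pen and paper" steps above fuse the separate q and kv Flax kernels into a single qkv projection. Assuming the Flax layouts are q: `(dim, heads, head_dim)` and kv: `(dim, heads, 2 * head_dim)` (an assumption inferred from the split/concat axes, not stated in this diff), here is a toy shape check of the same ops:

```python
import numpy as np
import torch

dim, heads, head_dim = 8, 2, 4  # dim == heads * head_dim
w_q = np.zeros((dim, heads, head_dim))
w_kv = np.zeros((dim, heads, 2 * head_dim))

w_kv = np.concatenate(np.split(w_kv, 2, -1), 1)   # (dim, 2 * heads, head_dim): k heads then v heads
w_qkv = np.concatenate([w_q, w_kv], 1)            # (dim, 3 * heads, head_dim): q | k | v
w = torch.tensor(w_qkv).flatten(1).permute(1, 0)  # (3 * dim, dim), i.e. nn.Linear(dim, 3 * dim).weight
assert w.shape == (3 * dim, dim)
```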
@@ -102,6 +104,6 @@ def convert_nest(checkpoint_path, arch):
 
 
 if __name__ == '__main__':
-    variant = 'base'
+    variant = sys.argv[1] # base, small, or tiny
     state_dict = convert_nest(f'../nested-transformer/checkpoints/nest-{variant[0]}_imagenet', f'nest_{variant}')
-    torch.save(state_dict, f'jx_nest_{variant}.pth')
+    torch.save(state_dict, f'/home/alexander/.cache/torch/hub/checkpoints/jx_nest_{variant}.pth')
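
The hard-coded `/home/alexander/...` path is machine-specific. A hypothetical, portable driver using torch hub's cache dir instead (`convert_nest` is the function defined in this script):

```python
import os
import sys
import torch

variant = sys.argv[1] if len(sys.argv) > 1 else 'base'  # base, small, or tiny
state_dict = convert_nest(f'../nested-transformer/checkpoints/nest-{variant[0]}_imagenet', f'nest_{variant}')
out_dir = os.path.join(torch.hub.get_dir(), 'checkpoints')  # resolves to ~/.cache/torch/hub/checkpoints by default
torch.save(state_dict, os.path.join(out_dir, f'jx_nest_{variant}.pth'))
```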

@@ -17,7 +17,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
 # transformer models don't support many of the spatial / feature based model functionalities
 NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
-    'convit_*', 'levit*', 'visformer*', 'deit*']
+    'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*']
 NUM_NON_STD = len(NON_STD_FILTERS)
 
 # exclude models that cause specific test failures
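
For context, these wildcard filters are matched fnmatch-style against model names, so the two new patterns cover both the converted-weight and scratch NesT variants. A minimal sketch of the effect:

```python
import fnmatch

# Subset of NON_STD_FILTERS above, just the two new patterns.
filters = ['jx_nest_*', 'nest_*']

def is_non_std(name: str) -> bool:
    return any(fnmatch.fnmatch(name, f) for f in filters)

assert is_non_std('jx_nest_base')
assert is_non_std('nest_tiny')
assert not is_non_std('resnet50')
```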

@@ -277,7 +277,6 @@ class Nest(nn.Module):
         # Patch embedding
         self.patch_embed = PatchEmbed(
             img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0])
-        self.feature_info = [dict(num_chs=embed_dims[0], reduction=patch_size, module='patch_embed')]
         self.num_patches = self.patch_embed.num_patches
         self.seq_length = self.num_patches // self.num_blocks[0]
@@ -291,8 +290,6 @@ class Nest(nn.Module):
                 self.num_blocks[lix], self.block_size, self.seq_length, num_heads[lix], depths[lix],
                 embed_dims[lix], mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, dpr, norm_layer,
                 act_layer))
-            self.feature_info.append(
-                dict(num_chs=embed_dims[lix], reduction=self.feature_info[-1]['reduction']*2, module=f'levels.{lix}'))
             if lix < self.num_levels - 1:
                 self.block_aggs.append(BlockAggregation(
                     embed_dims[lix], embed_dims[lix+1], norm_layer, pad_type=pad_type))
@@ -302,8 +299,6 @@ class Nest(nn.Module):
 
         # Final normalization layer
         self.norm = norm_layer(embed_dims[-1])
-        self.feature_info.append(
-            dict(num_chs=embed_dims[lix], reduction=self.feature_info[-1]['reduction'], module='norm'))
 
         # Classifier
         self.global_pool, self.head = create_classifier(
@@ -319,7 +314,7 @@ class Nest(nn.Module):
         if mode.startswith('jax'):
             named_apply(partial(_init_nest_weights, head_bias=head_bias, jax_impl=True), self)
         else:
-            self.apply(_init_nest_weights)
+            named_apply(_init_nest_weights, self)
 
     @torch.jit.ignore
     def no_weight_decay(self):
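
The switch from `self.apply` to `named_apply` means the init function also receives each module's qualified name, so initialization can special-case modules (e.g. the classifier head) by name. A rough sketch of what `named_apply` does; timm ships its own version in `timm.models.helpers`, and this is an approximation rather than that exact implementation:

```python
import torch.nn as nn

def named_apply(fn, module: nn.Module, name: str = '') -> nn.Module:
    # Apply fn to every submodule, passing the dotted module path along.
    fn(module=module, name=name)
    for child_name, child in module.named_children():
        named_apply(fn, child, name=f'{name}.{child_name}' if name else child_name)
    return module
```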
@@ -409,16 +404,14 @@ def checkpoint_filter_fn(state_dict, model):
 def _create_nest(variant, pretrained=False, default_cfg=None, **kwargs):
-    # if kwargs.get('features_only', None):
-    #     raise RuntimeError('features_only not implemented for Vision Transformer models.')
+    if kwargs.get('features_only', None):
+        raise RuntimeError('features_only not implemented for Vision Transformer models.')
     default_cfg = default_cfg or default_cfgs[variant]
     model = build_model_with_cfg(
         Nest, variant, pretrained,
         default_cfg=default_cfg,
         pretrained_filter_fn=checkpoint_filter_fn,
-        feature_cfg=dict(
-            out_indices=tuple(range(kwargs.get('num_levels', 3) + 2)), feature_cls='hook', flatten_sequential=True),
         **kwargs)
 
     return model
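
With the guard re-enabled and the hook-based `feature_cfg` dropped, requesting feature extraction now fails fast instead of half-working. Expected behaviour, sketched:

```python
import timm

model = timm.create_model('nest_base')  # fine
try:
    timm.create_model('nest_base', features_only=True)
except RuntimeError as err:
    print(err)  # features_only not implemented for Vision Transformer models.
```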
@@ -487,11 +480,4 @@ def jx_nest_tiny(pretrained=False, **kwargs):
     model_kwargs = dict(
         embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 8), drop_path_rate=0.2, **kwargs)
     model = _create_nest('jx_nest_tiny', pretrained=pretrained, **model_kwargs)
     return model
-
-
-if __name__ == '__main__':
-    model = jx_nest_base()
-    model = torch.jit.script(model)
-    inp = torch.zeros(8, 3, 224, 224)
-    print(model.forward_features(inp).shape)
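
The ad-hoc `__main__` smoke test removed above is still useful as a TorchScript regression check; it can live outside the module instead, e.g. (assuming the file is importable as `timm.models.nest`):

```python
import torch
from timm.models.nest import jx_nest_base

model = torch.jit.script(jx_nest_base())
inp = torch.zeros(8, 3, 224, 224)
print(model.forward_features(inp).shape)
```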