|
|
|
@ -3,6 +3,8 @@ Convert weights from https://github.com/google-research/nested-transformer
|
|
|
|
|
NOTE: You'll need https://github.com/google/CommonLoopUtils, not included in requirements.txt
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import sys
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
@ -38,7 +40,7 @@ def convert_nest(checkpoint_path, arch):
|
|
|
|
|
# Positional embeddings
|
|
|
|
|
posemb_keys = [k for k in flax_dict.keys() if k.startswith('PositionEmbedding')]
|
|
|
|
|
for i, k in enumerate(posemb_keys):
|
|
|
|
|
state_dict[f'pos_embed_{i}'] = torch.tensor(flax_dict[k]['pos_embedding'])
|
|
|
|
|
state_dict[f'levels.{i}.pos_embed'] = torch.tensor(flax_dict[k]['pos_embedding'])
|
|
|
|
|
|
|
|
|
|
# Transformer encoders
|
|
|
|
|
depths = arch_depths[arch]
|
|
|
|
@ -47,9 +49,9 @@ def convert_nest(checkpoint_path, arch):
|
|
|
|
|
global_layer_ix = sum(depths[:level]) + layer
|
|
|
|
|
# Norms
|
|
|
|
|
for i in range(2):
|
|
|
|
|
state_dict[f'ls_transformer_encoder.{level}.{layer}.norm{i+1}.weight'] = torch.tensor(
|
|
|
|
|
state_dict[f'levels.{level}.transformer_encoder.{layer}.norm{i+1}.weight'] = torch.tensor(
|
|
|
|
|
flax_dict[f'EncoderNDBlock_{global_layer_ix}'][f'LayerNorm_{i}']['scale'])
|
|
|
|
|
state_dict[f'ls_transformer_encoder.{level}.{layer}.norm{i+1}.bias'] = torch.tensor(
|
|
|
|
|
state_dict[f'levels.{level}.transformer_encoder.{layer}.norm{i+1}.bias'] = torch.tensor(
|
|
|
|
|
flax_dict[f'EncoderNDBlock_{global_layer_ix}'][f'LayerNorm_{i}']['bias'])
|
|
|
|
|
# Attention qkv
|
|
|
|
|
w_q = flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MultiHeadAttention_0']['DenseGeneral_0']['kernel']
|
|
|
|
@ -57,37 +59,37 @@ def convert_nest(checkpoint_path, arch):
|
|
|
|
|
# Pay attention to dims here (maybe get pen and paper)
|
|
|
|
|
w_kv = np.concatenate(np.split(w_kv, 2, -1), 1)
|
|
|
|
|
w_qkv = np.concatenate([w_q, w_kv], 1)
|
|
|
|
|
state_dict[f'ls_transformer_encoder.{level}.{layer}.attn.qkv.weight'] = torch.tensor(w_qkv).flatten(1).permute(1,0)
|
|
|
|
|
state_dict[f'levels.{level}.transformer_encoder.{layer}.attn.qkv.weight'] = torch.tensor(w_qkv).flatten(1).permute(1,0)
|
|
|
|
|
b_q = flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MultiHeadAttention_0']['DenseGeneral_0']['bias']
|
|
|
|
|
b_kv = flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MultiHeadAttention_0']['DenseGeneral_1']['bias']
|
|
|
|
|
# Pay attention to dims here (maybe get pen and paper)
|
|
|
|
|
b_kv = np.concatenate(np.split(b_kv, 2, -1), 0)
|
|
|
|
|
b_qkv = np.concatenate([b_q, b_kv], 0)
|
|
|
|
|
state_dict[f'ls_transformer_encoder.{level}.{layer}.attn.qkv.bias'] = torch.tensor(b_qkv).reshape(-1)
|
|
|
|
|
state_dict[f'levels.{level}.transformer_encoder.{layer}.attn.qkv.bias'] = torch.tensor(b_qkv).reshape(-1)
|
|
|
|
|
# Attention proj
|
|
|
|
|
w_proj = flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MultiHeadAttention_0']['proj_kernel']
|
|
|
|
|
w_proj = torch.tensor(w_proj).permute(2, 1, 0).flatten(1)
|
|
|
|
|
state_dict[f'ls_transformer_encoder.{level}.{layer}.attn.proj.weight'] = w_proj
|
|
|
|
|
state_dict[f'ls_transformer_encoder.{level}.{layer}.attn.proj.bias'] = torch.tensor(
|
|
|
|
|
state_dict[f'levels.{level}.transformer_encoder.{layer}.attn.proj.weight'] = w_proj
|
|
|
|
|
state_dict[f'levels.{level}.transformer_encoder.{layer}.attn.proj.bias'] = torch.tensor(
|
|
|
|
|
flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MultiHeadAttention_0']['bias'])
|
|
|
|
|
# MLP
|
|
|
|
|
for i in range(2):
|
|
|
|
|
state_dict[f'ls_transformer_encoder.{level}.{layer}.mlp.fc{i+1}.weight'] = torch.tensor(
|
|
|
|
|
state_dict[f'levels.{level}.transformer_encoder.{layer}.mlp.fc{i+1}.weight'] = torch.tensor(
|
|
|
|
|
flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MlpBlock_0'][f'Dense_{i}']['kernel']).permute(1, 0)
|
|
|
|
|
state_dict[f'ls_transformer_encoder.{level}.{layer}.mlp.fc{i+1}.bias'] = torch.tensor(
|
|
|
|
|
state_dict[f'levels.{level}.transformer_encoder.{layer}.mlp.fc{i+1}.bias'] = torch.tensor(
|
|
|
|
|
flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MlpBlock_0'][f'Dense_{i}']['bias'])
|
|
|
|
|
|
|
|
|
|
# Block aggregations
|
|
|
|
|
for level in range(len(depths)-1):
|
|
|
|
|
# Convs
|
|
|
|
|
state_dict[f'ls_block_aggregation.{level}.conv.weight'] = torch.tensor(
|
|
|
|
|
state_dict[f'block_aggs.{level}.conv.weight'] = torch.tensor(
|
|
|
|
|
flax_dict[f'ConvPool_{level}']['Conv_0']['kernel']).permute(3, 2, 0, 1)
|
|
|
|
|
state_dict[f'ls_block_aggregation.{level}.conv.bias'] = torch.tensor(
|
|
|
|
|
state_dict[f'block_aggs.{level}.conv.bias'] = torch.tensor(
|
|
|
|
|
flax_dict[f'ConvPool_{level}']['Conv_0']['bias'])
|
|
|
|
|
# Norms
|
|
|
|
|
state_dict[f'ls_block_aggregation.{level}.norm.weight'] = torch.tensor(
|
|
|
|
|
state_dict[f'block_aggs.{level}.norm.weight'] = torch.tensor(
|
|
|
|
|
flax_dict[f'ConvPool_{level}']['LayerNorm_0']['scale'])
|
|
|
|
|
state_dict[f'ls_block_aggregation.{level}.norm.bias'] = torch.tensor(
|
|
|
|
|
state_dict[f'block_aggs.{level}.norm.bias'] = torch.tensor(
|
|
|
|
|
flax_dict[f'ConvPool_{level}']['LayerNorm_0']['bias'])
|
|
|
|
|
|
|
|
|
|
# Final norm
|
|
|
|
@ -102,6 +104,6 @@ def convert_nest(checkpoint_path, arch):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
variant = 'base'
|
|
|
|
|
variant = sys.argv[1] # base, small, or tiny
|
|
|
|
|
state_dict = convert_nest(f'../nested-transformer/checkpoints/nest-{variant[0]}_imagenet', f'nest_{variant}')
|
|
|
|
|
torch.save(state_dict, f'jx_nest_{variant}.pth')
|
|
|
|
|
torch.save(state_dict, f'/home/alexander/.cache/torch/hub/checkpoints/jx_nest_{variant}.pth')
|