Move aggregation (convpool) for nest into NestLevel, cleanup and enable features_only use. Finalize weight url.

pull/738/head
Ross Wightman 4 years ago
parent 6ae0ac6420
commit 81cd6863c8

@ -23,6 +23,9 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
## What's New
### July 5, 2021
* Add 'Aggregating Nested Transformer' (NesT) w/ weights converted from official [Flax impl](https://github.com/google-research/nested-transformer). Contributed by [Alexander Soare](https://github.com/alexander-soare).
### June 23, 2021
* Reproduce gMLP model training, `gmlp_s16_224` trained to 79.6 top-1, matching [paper](https://arxiv.org/abs/2105.08050). Hparams for this and other recent MLP training [here](https://gist.github.com/rwightman/d6c264a9001f9167e06c209f630b2cc6)

@ -79,18 +79,18 @@ def convert_nest(checkpoint_path, arch):
state_dict[f'levels.{level}.transformer_encoder.{layer}.mlp.fc{i+1}.bias'] = torch.tensor(
flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MlpBlock_0'][f'Dense_{i}']['bias'])
# Block aggregations
for level in range(len(depths)-1):
# Block aggregations (ConvPool)
for level in range(1, len(depths)):
# Convs
state_dict[f'block_aggs.{level}.conv.weight'] = torch.tensor(
flax_dict[f'ConvPool_{level}']['Conv_0']['kernel']).permute(3, 2, 0, 1)
state_dict[f'block_aggs.{level}.conv.bias'] = torch.tensor(
flax_dict[f'ConvPool_{level}']['Conv_0']['bias'])
state_dict[f'levels.{level}.pool.conv.weight'] = torch.tensor(
flax_dict[f'ConvPool_{level-1}']['Conv_0']['kernel']).permute(3, 2, 0, 1)
state_dict[f'levels.{level}.pool.conv.bias'] = torch.tensor(
flax_dict[f'ConvPool_{level-1}']['Conv_0']['bias'])
# Norms
state_dict[f'block_aggs.{level}.norm.weight'] = torch.tensor(
flax_dict[f'ConvPool_{level}']['LayerNorm_0']['scale'])
state_dict[f'block_aggs.{level}.norm.bias'] = torch.tensor(
flax_dict[f'ConvPool_{level}']['LayerNorm_0']['bias'])
state_dict[f'levels.{level}.pool.norm.weight'] = torch.tensor(
flax_dict[f'ConvPool_{level-1}']['LayerNorm_0']['scale'])
state_dict[f'levels.{level}.pool.norm.bias'] = torch.tensor(
flax_dict[f'ConvPool_{level-1}']['LayerNorm_0']['bias'])
# Final norm
state_dict[f'norm.weight'] = torch.tensor(flax_dict['LayerNorm_0']['scale'])
@ -105,5 +105,5 @@ def convert_nest(checkpoint_path, arch):
if __name__ == '__main__':
variant = sys.argv[1] # base, small, or tiny
state_dict = convert_nest(f'../nested-transformer/checkpoints/nest-{variant[0]}_imagenet', f'nest_{variant}')
torch.save(state_dict, f'/home/alexander/.cache/torch/hub/checkpoints/jx_nest_{variant}.pth')
state_dict = convert_nest(f'./nest-{variant[0]}_imagenet', f'nest_{variant}')
torch.save(state_dict, f'./jx_nest_{variant}.pth')

@ -16,24 +16,19 @@ Copyright 2021 Alexander Soare
"""
import collections.abc
from functools import partial
import math
import logging
import math
from functools import partial
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, named_apply
from .layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_
from .layers.helpers import to_ntuple
from .layers.create_conv2d import create_conv2d
from .layers.pool2d_same import create_pool2d
from .vision_transformer import Block
from .layers import create_conv2d, create_pool2d, to_ntuple
from .registry import register_model
from .helpers import build_model_with_cfg, named_apply
from .vision_transformer import resize_pos_embed
_logger = logging.getLogger(__name__)
@ -54,9 +49,12 @@ default_cfgs = {
'nest_base': _cfg(),
'nest_small': _cfg(),
'nest_tiny': _cfg(),
'jx_nest_base': _cfg(url='https://www.todo-this-is-a-placeholder.com/jx_nest_base.pth'), # TODO
'jx_nest_small': _cfg(url='https://www.todo-this-is-a-placeholder.com/jx_nest_small.pth'), # TODO
'jx_nest_tiny': _cfg(url='https://www.todo-this-is-a-placeholder.com/jx_nest_tiny.pth'), # TODO
'jx_nest_base': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_base-8bc41011.pth'),
'jx_nest_small': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_small-422eaded.pth'),
'jx_nest_tiny': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_tiny-e3428fb9.pth'),
}
@ -96,7 +94,7 @@ class Attention(nn.Module):
return x # (B, T, N, C)
class TransformerLayer(Block):
class TransformerLayer(nn.Module):
"""
This is much like `.vision_transformer.Block` but:
- Called TransformerLayer here to allow for "block" as defined in the paper ("non-overlapping image blocks")
@ -104,8 +102,7 @@ class TransformerLayer(Block):
"""
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__(dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm)
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -120,7 +117,7 @@ class TransformerLayer(Block):
return x
class BlockAggregation(nn.Module):
class ConvPool(nn.Module):
def __init__(self, in_channels, out_channels, norm_layer, pad_type=''):
super().__init__()
self.conv = create_conv2d(in_channels, out_channels, kernel_size=3, padding=pad_type, bias=True)
@ -152,8 +149,7 @@ def blockify(x, block_size: int):
grid_height = H // block_size
grid_width = W // block_size
x = x.reshape(B, grid_height, block_size, grid_width, block_size, C)
x = x.permute(0, 1, 3, 2, 4, 5)
x = x.reshape(B, grid_height * grid_width, -1, C)
x = x.transpose(2, 3).reshape(B, grid_height * grid_width, -1, C)
return x # (B, T, N, C)
@ -165,21 +161,28 @@ def deblockify(x, block_size: int):
"""
B, T, _, C = x.shape
grid_size = int(math.sqrt(T))
x = x.reshape(B, grid_size, grid_size, block_size, block_size, C)
x = x.permute(0, 1, 3, 2, 4, 5)
height = width = grid_size * block_size
x = x.reshape(B, height, width, C)
x = x.reshape(B, grid_size, grid_size, block_size, block_size, C)
x = x.transpose(2, 3).reshape(B, height, width, C)
return x # (B, H, W, C)
class NestLevel(nn.Module):
""" Single hierarchical level of a Nested Transformer
"""
def __init__(self, num_blocks, block_size, seq_length, num_heads, depth, embed_dim, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rates=[], norm_layer=None, act_layer=None):
def __init__(
self, num_blocks, block_size, seq_length, num_heads, depth, embed_dim, prev_embed_dim=None,
mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rates=[],
norm_layer=None, act_layer=None, pad_type=''):
super().__init__()
self.block_size = block_size
self.pos_embed = nn.Parameter(torch.zeros(1, num_blocks, seq_length, embed_dim))
if prev_embed_dim is not None:
self.pool = ConvPool(prev_embed_dim, embed_dim, norm_layer=norm_layer, pad_type=pad_type)
else:
self.pool = nn.Identity()
# Transformer encoder
if len(drop_path_rates):
assert len(drop_path_rates) == depth, 'Must provide as many drop path rates as there are transformer layers'
@ -194,15 +197,14 @@ class NestLevel(nn.Module):
"""
expects x as (B, C, H, W)
"""
# Switch to channels last for transformer
x = x.permute(0, 2, 3, 1) # (B, H', W', C)
x = self.pool(x)
x = x.permute(0, 2, 3, 1) # (B, H', W', C), switch to channels last for transformer
x = blockify(x, self.block_size) # (B, T, N, C')
x = x + self.pos_embed
x = self.transformer_encoder(x) # (B, T, N, C')
x = deblockify(x, self.block_size) # (B, H', W', C')
# Channel-first for block aggregation, and generally to replicate convnet feature map at each stage
x = x.permute(0, 3, 1, 2) # (B, C, H', W')
return x
return x.permute(0, 3, 1, 2) # (B, C, H', W')
class Nest(nn.Module):
@ -213,9 +215,9 @@ class Nest(nn.Module):
"""
def __init__(self, img_size=224, in_chans=3, patch_size=4, num_levels=3, embed_dims=(128, 256, 512),
num_heads=(4, 8, 16), depths=(2, 2, 20), num_classes=1000, mlp_ratio=4., qkv_bias=True, pad_type='',
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.5, norm_layer=None, act_layer=None, weight_init='',
global_pool='avg'):
num_heads=(4, 8, 16), depths=(2, 2, 20), num_classes=1000, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.5, norm_layer=None, act_layer=None,
pad_type='', weight_init='', global_pool='avg'):
"""
Args:
img_size (int, tuple): input image size
@ -233,6 +235,7 @@ class Nest(nn.Module):
drop_path_rate (float): stochastic depth rate
norm_layer: (nn.Module): normalization layer for transformer layers
act_layer: (nn.Module): activation layer in MLP of transformer layers
pad_type: str: Type of padding to use '' for PyTorch symmetric, 'same' for TF SAME
weight_init: (str): weight init scheme
global_pool: (str): type of pooling operation to apply to final feature map
@ -254,6 +257,7 @@ class Nest(nn.Module):
depths = to_ntuple(num_levels)(depths)
self.num_classes = num_classes
self.num_features = embed_dims[-1]
self.feature_info = []
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
self.drop_rate = drop_rate
@ -265,60 +269,54 @@ class Nest(nn.Module):
self.patch_size = patch_size
# Number of blocks at each level
self.num_blocks = 4**(np.arange(num_levels)[::-1])
assert (img_size // patch_size) % np.sqrt(self.num_blocks[0]) == 0, \
self.num_blocks = (4 ** torch.arange(num_levels)).flip(0).tolist()
assert (img_size // patch_size) % math.sqrt(self.num_blocks[0]) == 0, \
'First level blocks don\'t fit evenly. Check `img_size`, `patch_size`, and `num_levels`'
# Block edge size in units of patches
# Hint: (img_size // patch_size) gives number of patches along edge of image. sqrt(self.num_blocks[0]) is the
# number of blocks along edge of image
self.block_size = int((img_size // patch_size) // np.sqrt(self.num_blocks[0]))
self.block_size = int((img_size // patch_size) // math.sqrt(self.num_blocks[0]))
# Patch embedding
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0])
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0], flatten=False)
self.num_patches = self.patch_embed.num_patches
self.seq_length = self.num_patches // self.num_blocks[0]
# Build up each hierarchical level
self.levels = nn.ModuleList([])
self.block_aggs = nn.ModuleList([])
drop_path_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
for lix in range(self.num_levels):
dpr = drop_path_rates[sum(depths[:lix]):sum(depths[:lix+1])]
self.levels.append(NestLevel(
self.num_blocks[lix], self.block_size, self.seq_length, num_heads[lix], depths[lix],
embed_dims[lix], mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, dpr, norm_layer,
act_layer))
if lix < self.num_levels - 1:
self.block_aggs.append(BlockAggregation(
embed_dims[lix], embed_dims[lix+1], norm_layer, pad_type=pad_type))
else:
# Required for zipped iteration over levels and ls_block_agg together
self.block_aggs.append(nn.Identity())
levels = []
dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
prev_dim = None
curr_stride = 4
for i in range(len(self.num_blocks)):
dim = embed_dims[i]
levels.append(NestLevel(
self.num_blocks[i], self.block_size, self.seq_length, num_heads[i], depths[i], dim, prev_dim,
mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, dp_rates[i], norm_layer, act_layer, pad_type=pad_type))
self.feature_info += [dict(num_chs=dim, reduction=curr_stride, module=f'levels.{i}')]
prev_dim = dim
curr_stride *= 2
self.levels = nn.Sequential(*levels)
# Final normalization layer
self.norm = norm_layer(embed_dims[-1])
# Classifier
self.global_pool, self.head = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
self.init_weights(weight_init)
def init_weights(self, mode=''):
assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
assert mode in ('nlhb', '')
head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
for level in self.levels:
trunc_normal_(level.pos_embed, std=.02, a=-2, b=2)
if mode.startswith('jax'):
named_apply(partial(_init_nest_weights, head_bias=head_bias, jax_impl=True), self)
else:
named_apply(_init_nest_weights, self)
named_apply(partial(_init_nest_weights, head_bias=head_bias), self)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed'}
return {f'level.{i}.pos_embed' for i in range(len(self.levels))}
def get_classifier(self):
return self.head
@ -333,13 +331,8 @@ class Nest(nn.Module):
"""
B, _, H, W = x.shape
x = self.patch_embed(x)
x = x.reshape(B, H//self.patch_size, W//self.patch_size, -1) # (B, H', W', C')
x = x.permute(0, 3, 1, 2)
# NOTE: TorchScript won't let us subscript module lists with integer variables, so we iterate instead
for level, block_agg in zip(self.levels, self.block_aggs):
x = level(x)
x = block_agg(x)
# Layer norm done over channel dim only
x = self.levels(x)
# Layer norm done over channel dim only (to NHWC and back)
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x
@ -353,22 +346,19 @@ class Nest(nn.Module):
return self.head(x)
def _init_nest_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False):
def _init_nest_weights(module: nn.Module, name: str = '', head_bias: float = 0.):
""" NesT weight initialization
Can replicate Jax implementation. Otherwise follows vision_transformer.py
"""
if isinstance(module, nn.Linear):
if name.startswith('head'):
if jax_impl:
trunc_normal_(module.weight, std=.02, a=-2, b=2)
else:
nn.init.zeros_(module.weight)
nn.init.constant_(module.bias, head_bias)
else:
trunc_normal_(module.weight, std=.02, a=-2, b=2)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif jax_impl and isinstance(module, nn.Conv2d):
elif isinstance(module, nn.Conv2d):
trunc_normal_(module.weight, std=.02, a=-2, b=2)
if module.bias is not None:
nn.init.zeros_(module.bias)
@ -404,13 +394,11 @@ def checkpoint_filter_fn(state_dict, model):
def _create_nest(variant, pretrained=False, default_cfg=None, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
default_cfg = default_cfg or default_cfgs[variant]
model = build_model_with_cfg(
Nest, variant, pretrained,
default_cfg=default_cfg,
feature_cfg=dict(out_indices=(0, 1, 2), flatten_sequential=True),
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs)
@ -422,7 +410,7 @@ def nest_base(pretrained=False, **kwargs):
""" Nest-B @ 224x224
"""
model_kwargs = dict(
embed_dims=(128, 256, 512), num_heads=(4, 8, 16), depths=(2, 2, 20), drop_path_rate=0.5, **kwargs)
embed_dims=(128, 256, 512), num_heads=(4, 8, 16), depths=(2, 2, 20), **kwargs)
model = _create_nest('nest_base', pretrained=pretrained, **model_kwargs)
return model
@ -431,8 +419,7 @@ def nest_base(pretrained=False, **kwargs):
def nest_small(pretrained=False, **kwargs):
""" Nest-S @ 224x224
"""
model_kwargs = dict(
embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 20), drop_path_rate=0.3, **kwargs)
model_kwargs = dict(embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 20), **kwargs)
model = _create_nest('nest_small', pretrained=pretrained, **model_kwargs)
return model
@ -441,8 +428,7 @@ def nest_small(pretrained=False, **kwargs):
def nest_tiny(pretrained=False, **kwargs):
""" Nest-T @ 224x224
"""
model_kwargs = dict(
embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 8), drop_path_rate=0.2, **kwargs)
model_kwargs = dict(embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 8), **kwargs)
model = _create_nest('nest_tiny', pretrained=pretrained, **model_kwargs)
return model
@ -452,9 +438,7 @@ def jx_nest_base(pretrained=False, **kwargs):
""" Nest-B @ 224x224, Pretrained weights converted from official Jax impl.
"""
kwargs['pad_type'] = 'same'
kwargs['weight_init'] = 'jax'
model_kwargs = dict(
embed_dims=(128, 256, 512), num_heads=(4, 8, 16), depths=(2, 2, 20), drop_path_rate=0.5, **kwargs)
model_kwargs = dict(embed_dims=(128, 256, 512), num_heads=(4, 8, 16), depths=(2, 2, 20), **kwargs)
model = _create_nest('jx_nest_base', pretrained=pretrained, **model_kwargs)
return model
@ -464,9 +448,7 @@ def jx_nest_small(pretrained=False, **kwargs):
""" Nest-S @ 224x224, Pretrained weights converted from official Jax impl.
"""
kwargs['pad_type'] = 'same'
kwargs['weight_init'] = 'jax'
model_kwargs = dict(
embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 20), drop_path_rate=0.3, **kwargs)
model_kwargs = dict(embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 20), **kwargs)
model = _create_nest('jx_nest_small', pretrained=pretrained, **model_kwargs)
return model
@ -476,8 +458,6 @@ def jx_nest_tiny(pretrained=False, **kwargs):
""" Nest-T @ 224x224, Pretrained weights converted from official Jax impl.
"""
kwargs['pad_type'] = 'same'
kwargs['weight_init'] = 'jax'
model_kwargs = dict(
embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 8), drop_path_rate=0.2, **kwargs)
model_kwargs = dict(embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 8), **kwargs)
model = _create_nest('jx_nest_tiny', pretrained=pretrained, **model_kwargs)
return model
Loading…
Cancel
Save