You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
pytorch-image-models/timm/models/gcvit.py

593 lines
21 KiB

""" Global Context ViT
From scratch implementation of GCViT in the style of timm swin_transformer_v2_cr.py
Global Context Vision Transformers - https://arxiv.org/abs/2206.09959
@article{hatamizadeh2022global,
title={Global Context Vision Transformers},
author={Hatamizadeh, Ali and Yin, Hongxu and Kautz, Jan and Molchanov, Pavlo},
journal={arXiv preprint arXiv:2206.09959},
year={2022}
}
Free of any code related to NVIDIA GCVit impl at https://github.com/NVlabs/GCVit.
The license for this code release is Apache 2.0 with no commercial restrictions.
However, weight files adapted from NVIDIA GCVit impl ARE under a non-commercial share-alike license
(https://creativecommons.org/licenses/by-nc-sa/4.0/) until I have a chance to train new ones...
Hacked together by / Copyright 2022, Ross Wightman
"""
import math
from functools import partial
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg, named_apply
from .layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d,\
get_attn, get_act_layer, get_norm_layer, _assert
from .registry import register_model
from .vision_transformer_relpos import RelPosMlp, RelPosBias # FIXME move to common location
__all__ = ['GlobalContextVit']
def _cfg(url='', **kwargs):
    """Return the default pretrained-config dict for a GCViT variant.

    Args:
        url: Download URL of the pretrained weights ('' when there are none).
        **kwargs: Extra entries merged in last, so they may override any default below.
    """
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=(7, 7),
        crop_pct=0.875,
        interpolation='bicubic',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='stem.conv1',
        classifier='head.fc',
        fixed_input_size=True,
    )
    cfg.update(kwargs)
    return cfg
# Per-variant default configs, keyed by model name (matching the @register_model functions below).
# Each url points at weights adapted from the NVIDIA GCVit release; per the module docstring those
# weight files are under a non-commercial (CC BY-NC-SA 4.0) license.
default_cfgs = {
    'gcvit_xxtiny': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_xxtiny_224_nvidia-d1d86009.pth'),
    'gcvit_xtiny': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_xtiny_224_nvidia-274b92b7.pth'),
    'gcvit_tiny': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_tiny_224_nvidia-ac783954.pth'),
    'gcvit_small': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_small_224_nvidia-4e98afa2.pth'),
    'gcvit_base': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_base_224_nvidia-f009139b.pth'),
}
class MbConvBlock(nn.Module):
    """ A depthwise separable / fused mbconv style residual block with SE, no norm.

    Layout (NCHW): 3x3 depthwise conv -> act -> attention module -> 1x1 pointwise conv, plus an
    identity shortcut when the input and output channel counts match.
    """
    def __init__(
            self,
            in_chs,
            out_chs=None,
            expand_ratio=1.0,
            attn_layer='se',
            bias=False,
            act_layer=nn.GELU,
    ):
        """
        Args:
            in_chs: Number of input channels.
            out_chs: Number of output channels (defaults to ``in_chs``).
            expand_ratio: Multiplier giving the depthwise conv output width ``int(expand_ratio * in_chs)``;
                that width must be divisible by ``in_chs`` because the conv uses ``groups=in_chs``.
            attn_layer: Attention module name or type, resolved through ``get_attn``.
            bias: Whether the two convolutions use a bias.
            act_layer: Activation module class.
        """
        super().__init__()
        attn_kwargs = dict(act_layer=act_layer)
        # 'se' / 'eca' get a fixed 1/4 reduction ratio and bias-free projections.
        if isinstance(attn_layer, str) and attn_layer in ('se', 'eca'):
            attn_kwargs['rd_ratio'] = 0.25
            attn_kwargs['bias'] = False
        attn_layer = get_attn(attn_layer)
        out_chs = out_chs or in_chs
        mid_chs = int(expand_ratio * in_chs)
        # The identity shortcut is only shape-compatible when the channel count is preserved.
        self.has_skip = in_chs == out_chs

        self.conv_dw = nn.Conv2d(in_chs, mid_chs, 3, 1, 1, groups=in_chs, bias=bias)
        self.act = act_layer()
        self.se = attn_layer(mid_chs, **attn_kwargs)
        self.conv_pw = nn.Conv2d(mid_chs, out_chs, 1, 1, 0, bias=bias)

    def forward(self, x):
        shortcut = x
        x = self.conv_dw(x)
        x = self.act(x)
        x = self.se(x)
        x = self.conv_pw(x)
        if self.has_skip:
            x = x + shortcut
        return x
class Downsample2d(nn.Module):
    """ norm -> MbConv -> 2x spatial reduction -> norm, on NCHW tensors.

    ``reduction`` selects a 3x3 stride-2 conv ('conv', may change channels), a 3x3 stride-2 max-pool
    ('max') or a 2x2 average pool ('avg'); the two pooling variants require ``dim_out == dim``.
    A ``norm_layer`` of None disables both norms.
    """
    def __init__(
            self,
            dim,
            dim_out=None,
            reduction='conv',
            act_layer=nn.GELU,
            norm_layer=LayerNorm2d,  # NOTE in NCHW
    ):
        super().__init__()
        dim_out = dim_out or dim

        def _make_norm(num_chs):
            return nn.Identity() if norm_layer is None else norm_layer(num_chs)

        self.norm1 = _make_norm(dim)
        self.conv_block = MbConvBlock(dim, act_layer=act_layer)

        assert reduction in ('conv', 'max', 'avg')
        if reduction == 'conv':
            self.reduction = nn.Conv2d(dim, dim_out, 3, 2, 1, bias=False)
        else:
            # pooling cannot change the channel count
            assert dim == dim_out
            if reduction == 'max':
                self.reduction = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            else:
                self.reduction = nn.AvgPool2d(kernel_size=2)
        self.norm2 = _make_norm(dim_out)

    def forward(self, x):
        return self.norm2(self.reduction(self.conv_block(self.norm1(x))))
class FeatureBlock(nn.Module):
    """ Stack of MbConv blocks, each followed by a 2x pooling op while reductions remain (NCHW).

    For ``levels > 0`` this builds ``levels`` (MbConv, pool) pairs. For ``levels <= 0`` it builds a
    single MbConv with no pooling.
    """
    def __init__(
            self,
            dim,
            levels=0,
            reduction='max',
            act_layer=nn.GELU,
    ):
        """
        Args:
            dim: Channel count (unchanged through the block).
            levels: Number of 2x reductions to apply; values <= 0 mean none.
            reduction: 'avg' for 2x2 average pooling, anything else for 3x3 stride-2 max pooling.
            act_layer: Activation module class used inside each MbConv block.
        """
        super().__init__()
        # Clamp at 0: a negative level count (feature map already at or below the target size) must
        # not be treated as truthy and add a pooling op.
        reductions = max(0, levels)
        levels = max(1, levels)
        if reduction == 'avg':
            pool_fn = partial(nn.AvgPool2d, kernel_size=2)
        else:
            pool_fn = partial(nn.MaxPool2d, kernel_size=3, stride=2, padding=1)
        self.blocks = nn.Sequential()
        for i in range(levels):
            self.blocks.add_module(f'conv{i+1}', MbConvBlock(dim, act_layer=act_layer))
            if reductions > 0:
                self.blocks.add_module(f'pool{i+1}', pool_fn())
                reductions -= 1

    def forward(self, x):
        return self.blocks(x)
class Stem(nn.Module):
    """ Convolutional stem (NCHW): a 3x3 stride-2 conv followed by a `Downsample2d`.

    With the default 'conv' reduction of `Downsample2d`, the stem reduces resolution 4x overall.
    """
    def __init__(
            self,
            in_chs: int = 3,
            out_chs: int = 96,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm2d,  # NOTE stem in NCHW
    ):
        super().__init__()
        self.conv1 = nn.Conv2d(in_chs, out_chs, 3, stride=2, padding=1)
        self.down = Downsample2d(out_chs, act_layer=act_layer, norm_layer=norm_layer)

    def forward(self, x):
        return self.down(self.conv1(x))
class WindowAttentionGlobal(nn.Module):
    """ Multi-head window attention with an optional externally supplied (global) query.

    With ``use_global=True`` the ``qkv`` projection only produces keys and values. Queries come from
    the ``q_global`` argument of ``forward``. Otherwise this is plain window self-attention. A relative
    position bias (``RelPosBias``) is applied to the attention logits in both modes.
    """
    def __init__(
            self,
            dim: int,
            num_heads: int,
            window_size: Tuple[int, int],
            use_global: bool = True,
            qkv_bias: bool = True,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
    ):
        super().__init__()
        window_size = to_2tuple(window_size)
        self.window_size = window_size
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.use_global = use_global

        self.rel_pos = RelPosBias(window_size=window_size, num_heads=num_heads)
        # global mode projects only k and v (2 * dim); q is supplied by the caller
        if self.use_global:
            self.qkv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        else:
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, q_global: Optional[torch.Tensor] = None):
        """
        Args:
            x: Window tokens of shape (num_images * num_windows, N, C). This assumes windows are
                ordered image-major (all windows of image 0 first), the order produced by
                ``window_partition``.
            q_global: Per-image global query whose leading dim is the image count and which holds
                N * C elements per image. It is required when ``use_global`` is set.

        Returns:
            Tensor with the same shape as ``x``.
        """
        B, N, C = x.shape
        if self.use_global and q_global is not None:
            _assert(x.shape[-1] == q_global.shape[-1], 'x and q_global channel dims should be equal')

            kv = self.qkv(x)
            kv = kv.reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
            k, v = kv.unbind(0)

            # Windows are grouped per image, so each image's query must be repeated contiguously
            # (repeat_interleave). Tiling the whole batch with `repeat` would pair window k with image
            # k % num_images instead of k // num_windows, mixing images whenever the batch is > 1.
            q = q_global.repeat_interleave(B // q_global.shape[0], dim=0)
            q = q.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        else:
            # a k/v-only projection cannot be split into q, k, v
            _assert(not self.use_global, 'q_global must be provided when use_global=True')
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)
        q = q * self.scale

        attn = (q @ k.transpose(-2, -1))
        attn = self.rel_pos(attn)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
def window_partition(x, window_size: Tuple[int, int]):
    """Split an NHWC tensor into non-overlapping windows.

    Args:
        x: Tensor of shape (B, H, W, C); H and W must be multiples of the window size.
        window_size: (window_h, window_w).

    Returns:
        Tensor of shape (B * num_windows, window_h, window_w, C). Windows are ordered image-major,
        then row-major over the window grid within each image.
    """
    B, H, W, C = x.shape
    win_h, win_w = window_size[0], window_size[1]
    grid = x.view(B, H // win_h, win_h, W // win_w, win_w, C)
    # move both window-grid axes next to the batch axis, then fold them into it
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(-1, win_h, win_w, C)
@register_notrace_function  # reason: int argument is a Proxy
def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]):
    """Inverse of `window_partition`: stitch windows back into an NHWC map.

    Args:
        windows: Tensor of shape (B * num_windows, window_h, window_w, C).
        window_size: (window_h, window_w).
        img_size: (H, W) of the map to rebuild; each must be a multiple of the window size.

    Returns:
        Tensor of shape (B, H, W, C).
    """
    win_h, win_w = window_size[0], window_size[1]
    H, W = img_size
    windows_per_image = H * W / win_h / win_w
    B = int(windows.shape[0] / windows_per_image)
    grid = windows.view(B, H // win_h, W // win_w, win_h, win_w, -1)
    # interleave grid axes with in-window axes -> (B, H//wh, wh, W//ww, ww, C), then merge
    return grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
class LayerScale(nn.Module):
    """Multiply the last dimension of the input by a learnable per-channel vector ``gamma``.

    Args:
        dim: Size of the last (channel) dimension.
        init_values: Initial value of every element of ``gamma``.
        inplace: Scale the input tensor itself (``mul_``) instead of returning a new tensor.
    """
    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(torch.ones(dim) * init_values)

    def forward(self, x):
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
class GlobalContextVitBlock(nn.Module):
    """ GCViT transformer block operating on NHWC tensors.

    Pre-norm windowed attention (with an optional global query) followed by a pre-norm MLP. Each
    residual branch has an optional LayerScale (when ``layer_scale`` is set) and DropPath (when
    ``drop_path > 0``).
    """
    def __init__(
            self,
            dim: int,
            feat_size: Tuple[int, int],
            num_heads: int,
            window_size: int = 7,
            mlp_ratio: float = 4.,
            use_global: bool = True,
            qkv_bias: bool = True,
            layer_scale: Optional[float] = None,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
            drop_path: float = 0.,
            attn_layer: Callable = WindowAttentionGlobal,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = nn.LayerNorm,
    ):
        super().__init__()
        feat_size = to_2tuple(feat_size)
        window_size = to_2tuple(window_size)
        self.window_size = window_size
        grid_h = feat_size[0] // window_size[0]
        grid_w = feat_size[1] // window_size[1]
        self.num_windows = int(grid_h * grid_w)

        def _scale():
            return LayerScale(dim, layer_scale) if layer_scale is not None else nn.Identity()

        def _drop():
            return DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # attention branch
        self.norm1 = norm_layer(dim)
        self.attn = attn_layer(
            dim,
            num_heads=num_heads,
            window_size=window_size,
            use_global=use_global,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )
        self.ls1 = _scale()
        self.drop_path1 = _drop()

        # MLP branch
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop)
        self.ls2 = _scale()
        self.drop_path2 = _drop()

    def _window_attn(self, x, q_global: Optional[torch.Tensor] = None):
        _, H, W, C = x.shape
        tokens_per_window = self.window_size[0] * self.window_size[1]
        windows = window_partition(x, self.window_size).view(-1, tokens_per_window, C)
        attended = self.attn(windows, q_global)
        return window_reverse(attended, self.window_size, (H, W))

    def forward(self, x, q_global: Optional[torch.Tensor] = None):
        attn_out = self._window_attn(self.norm1(x), q_global)
        x = x + self.drop_path1(self.ls1(attn_out))
        mlp_out = self.mlp(self.norm2(x))
        x = x + self.drop_path2(self.ls2(mlp_out))
        return x
class GlobalContextVitStage(nn.Module):
    """ One GCViT stage, NCHW in / NCHW out.

    Runs an optional `Downsample2d` (2x spatial reduction, 2x channels). A `FeatureBlock` then pools
    the feature map by ``int(log2(min(feat) / min(window)))`` 2x steps to form the global query,
    followed by ``depth`` `GlobalContextVitBlock`s on NHWC tensors. Blocks at even indices use local
    window self-attention and blocks at odd indices use the global query.
    """
    def __init__(
            self,
            dim,
            depth: int,
            num_heads: int,
            feat_size: Tuple[int, int],
            window_size: Tuple[int, int],
            downsample: bool = True,
            global_norm: bool = False,
            stage_norm: bool = False,
            mlp_ratio: float = 4.,
            qkv_bias: bool = True,
            layer_scale: Optional[float] = None,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
            drop_path: Union[List[float], float] = 0.0,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm2d,
            norm_layer_cl: Callable = nn.LayerNorm,
    ):
        """
        Args:
            dim: Input channels; doubled internally when ``downsample`` is set.
            depth: Number of transformer blocks.
            num_heads: Attention heads per block.
            feat_size: Input feature map (H, W); halved internally when ``downsample`` is set.
            window_size: Attention window size.
            downsample: Prepend a `Downsample2d`.
            global_norm: Apply ``norm_layer_cl`` to the global query.
            stage_norm: Apply ``norm_layer_cl`` to the stage output.
            mlp_ratio, qkv_bias, layer_scale, proj_drop, attn_drop: Forwarded to each block.
            drop_path: Stochastic depth rate, either shared or a list with one value per block.
            act_layer: Activation class used by the downsample, the global feature block and the MLPs.
            norm_layer: Norm for NCHW tensors (used by the downsample).
            norm_layer_cl: Norm for channels-last NHWC tensors (blocks, global and stage norms).
        """
        super().__init__()
        if downsample:
            self.downsample = Downsample2d(
                dim=dim,
                dim_out=dim * 2,
                act_layer=act_layer,
                norm_layer=norm_layer,
            )
            dim = dim * 2
            feat_size = (feat_size[0] // 2, feat_size[1] // 2)
        else:
            self.downsample = nn.Identity()

        window_size = to_2tuple(window_size)
        # number of 2x poolings needed to bring the feature map down to the window size
        feat_levels = int(math.log2(min(feat_size) / min(window_size)))
        self.global_block = FeatureBlock(dim, feat_levels, act_layer=act_layer)
        self.global_norm = norm_layer_cl(dim) if global_norm else nn.Identity()

        self.blocks = nn.ModuleList([
            GlobalContextVitBlock(
                dim=dim,
                num_heads=num_heads,
                feat_size=feat_size,
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                use_global=(i % 2 != 0),
                layer_scale=layer_scale,
                proj_drop=proj_drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                act_layer=act_layer,
                norm_layer=norm_layer_cl,
            )
            for i in range(depth)
        ])
        self.norm = norm_layer_cl(dim) if stage_norm else nn.Identity()
        self.dim = dim
        self.feat_size = feat_size
        self.grad_checkpointing = False

    def forward(self, x):
        # input NCHW, downsample & global block are 2d conv + pooling
        x = self.downsample(x)
        global_query = self.global_block(x)

        # reshape NCHW --> NHWC for transformer blocks
        x = x.permute(0, 2, 3, 1)
        global_query = self.global_norm(global_query.permute(0, 2, 3, 1))
        for blk in self.blocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # global_query must be forwarded here too; otherwise blocks built with
                # use_global=True would run with q_global=None.
                x = checkpoint.checkpoint(blk, x, global_query)
            else:
                x = blk(x, global_query)
        x = self.norm(x)
        x = x.permute(0, 3, 1, 2).contiguous()  # back to NCHW
        return x
class GlobalContextVit(nn.Module):
    """ Global Context Vision Transformer (GCViT) classifier.

    A convolutional `Stem` (4x reduction), then ``len(depths)`` `GlobalContextVitStage`s, then a
    `ClassifierHead`. Every stage after the first halves resolution and doubles channels. Tensors
    passed between stages are NCHW.

    Args:
        in_chans: Number of input image channels.
        num_classes: Number of classifier outputs.
        global_pool: Pooling type for the classifier head.
        img_size: Input resolution (int or (H, W)); used to derive per-stage feature and window sizes.
        window_ratio: Per-stage divisor of ``img_size`` giving the window size; only used when
            ``window_size`` is None.
        window_size: Explicit window size (int or one entry per stage); overrides ``window_ratio``.
        embed_dim: Channel width of the stem and first stage.
        depths: Number of blocks per stage.
        num_heads: Attention heads per stage.
        mlp_ratio: MLP hidden width multiplier.
        qkv_bias: Use a bias in the attention qkv projection.
        layer_scale: Initial LayerScale value; None disables LayerScale.
        drop_rate: Dropout rate passed to the classifier head.
        proj_drop_rate: Dropout for attention output projections and MLPs.
        attn_drop_rate: Dropout on attention weights.
        drop_path_rate: Final stochastic depth rate; rates increase linearly over all blocks.
        weight_init: If non-empty, the scheme passed to `_init_weights` ('vit' or other).
        act_layer: Activation name resolved via ``get_act_layer``.
        norm_layer: Norm name for NCHW tensors (stem and downsample layers).
        norm_layer_cl: Norm name for channels-last (NHWC) tensors inside the stages.
        norm_eps: Epsilon used for both norm types.
    """
    def __init__(
            self,
            in_chans: int = 3,
            num_classes: int = 1000,
            global_pool: str = 'avg',
            img_size: Union[int, Tuple[int, int]] = 224,
            window_ratio: Tuple[int, ...] = (32, 32, 16, 32),
            window_size: Optional[Union[int, Tuple[int, ...]]] = None,
            embed_dim: int = 64,
            depths: Tuple[int, ...] = (3, 4, 19, 5),
            num_heads: Tuple[int, ...] = (2, 4, 8, 16),
            mlp_ratio: float = 3.0,
            qkv_bias: bool = True,
            layer_scale: Optional[float] = None,
            drop_rate: float = 0.,
            proj_drop_rate: float = 0.,
            attn_drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            weight_init: str = '',
            act_layer: str = 'gelu',
            norm_layer: str = 'layernorm2d',
            norm_layer_cl: str = 'layernorm',
            norm_eps: float = 1e-5,
    ):
        super().__init__()
        act_layer = get_act_layer(act_layer)
        # NCHW norm (stem / downsample) and channels-last norm (transformer blocks), sharing one eps
        norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
        norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps)

        img_size = to_2tuple(img_size)
        feat_size = tuple(d // 4 for d in img_size)  # stem reduction by 4
        self.global_pool = global_pool
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        num_stages = len(depths)
        # every stage after the first doubles the channel count
        self.num_features = int(embed_dim * 2 ** (num_stages - 1))
        if window_size is not None:
            window_size = to_ntuple(num_stages)(window_size)
        else:
            assert window_ratio is not None
            # derive each stage's window from the input size, e.g. 224 // 32 = 7
            window_size = tuple([(img_size[0] // r, img_size[1] // r) for r in to_ntuple(num_stages)(window_ratio)])

        self.stem = Stem(
            in_chs=in_chans,
            out_chs=embed_dim,
            act_layer=act_layer,
            norm_layer=norm_layer
        )

        # stochastic depth rate rises linearly over all blocks, then is split into per-stage lists
        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
        stages = []
        for i in range(num_stages):
            last_stage = i == num_stages - 1
            # channel / resolution scale of this stage's *input*; stages with downsample=True (all but
            # the first) then double channels and halve resolution internally
            stage_scale = 2 ** max(i - 1, 0)
            stages.append(GlobalContextVitStage(
                dim=embed_dim * stage_scale,
                depth=depths[i],
                num_heads=num_heads[i],
                feat_size=(feat_size[0] // stage_scale, feat_size[1] // stage_scale),
                window_size=window_size[i],
                downsample=i != 0,
                stage_norm=last_stage,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                layer_scale=layer_scale,
                proj_drop=proj_drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                act_layer=act_layer,
                norm_layer=norm_layer,
                norm_layer_cl=norm_layer_cl,
            ))
        self.stages = nn.Sequential(*stages)

        # Classifier head
        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)

        if weight_init:
            named_apply(partial(self._init_weights, scheme=weight_init), self)

    def _init_weights(self, module, name, scheme='vit'):
        """Re-initialize nn.Linear layers; every other module type is left untouched.

        With scheme 'vit': xavier-uniform weights, and biases zeroed except for modules whose name
        contains 'mlp', whose biases get N(0, 1e-6). With any other scheme: N(0, 0.02) weights and zero
        biases.
        """
        # note Conv2d left as default init
        if scheme == 'vit':
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    if 'mlp' in name:
                        nn.init.normal_(module.bias, std=1e-6)
                    else:
                        nn.init.zeros_(module.bias)
        else:
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, std=.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Names of parameters to exclude from weight decay (relative-position bias params)."""
        return {
            k for k, _ in self.named_parameters()
            if any(n in k for n in ["relative_position_bias_table", "rel_pos.mlp"])}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Regexes grouping parameter names into the stem and each stage (``coarse`` is unused)."""
        matcher = dict(
            stem=r'^stem',  # stem and embed
            blocks=r'^stages\.(\d+)'
        )
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Enable or disable gradient checkpointing in every stage."""
        for s in self.stages:
            s.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        """Return the head's final classifier module (``head.fc``)."""
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool=None):
        """Replace the head with a new `ClassifierHead` producing ``num_classes`` outputs.

        If ``global_pool`` is None, the current head's pool type is kept.
        """
        self.num_classes = num_classes
        if global_pool is None:
            global_pool = self.head.global_pool.pool_type
        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Run the stem and all stages; returns the last stage's NCHW feature map."""
        x = self.stem(x)
        x = self.stages(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        """Apply the classifier head; ``pre_logits`` is passed through to `ClassifierHead`.

        NOTE(review): presumably ``pre_logits=True`` returns pooled features without the final fc;
        confirm against ClassifierHead.
        """
        return self.head(x, pre_logits=pre_logits)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Full forward pass: features followed by the classifier head."""
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
def _create_gcvit(variant, pretrained=False, **kwargs):
    """Build a `GlobalContextVit` for ``variant`` via `build_model_with_cfg`.

    Raises:
        RuntimeError: if ``features_only`` is requested; feature extraction is not supported here.
    """
    if not kwargs.get('features_only', None):
        return build_model_with_cfg(GlobalContextVit, variant, pretrained, **kwargs)
    raise RuntimeError('features_only not implemented for Vision Transformer models.')
@register_model
def gcvit_xxtiny(pretrained=False, **kwargs):
    """GCViT-XXTiny: depths (2, 2, 6, 2), heads (2, 4, 8, 16); other settings use the class defaults."""
    return _create_gcvit(
        'gcvit_xxtiny',
        pretrained=pretrained,
        depths=(2, 2, 6, 2),
        num_heads=(2, 4, 8, 16),
        **kwargs,
    )
@register_model
def gcvit_xtiny(pretrained=False, **kwargs):
    """GCViT-XTiny: depths (3, 4, 6, 5), heads (2, 4, 8, 16); other settings use the class defaults."""
    return _create_gcvit(
        'gcvit_xtiny',
        pretrained=pretrained,
        depths=(3, 4, 6, 5),
        num_heads=(2, 4, 8, 16),
        **kwargs,
    )
@register_model
def gcvit_tiny(pretrained=False, **kwargs):
    """GCViT-Tiny: depths (3, 4, 19, 5), heads (2, 4, 8, 16); other settings use the class defaults."""
    return _create_gcvit(
        'gcvit_tiny',
        pretrained=pretrained,
        depths=(3, 4, 19, 5),
        num_heads=(2, 4, 8, 16),
        **kwargs,
    )
@register_model
def gcvit_small(pretrained=False, **kwargs):
    """GCViT-Small: embed_dim 96, depths (3, 4, 19, 5), heads (3, 6, 12, 24), mlp_ratio 2, layer scale 1e-5."""
    return _create_gcvit(
        'gcvit_small',
        pretrained=pretrained,
        depths=(3, 4, 19, 5),
        num_heads=(3, 6, 12, 24),
        embed_dim=96,
        mlp_ratio=2,
        layer_scale=1e-5,
        **kwargs,
    )
@register_model
def gcvit_base(pretrained=False, **kwargs):
    """GCViT-Base: embed_dim 128, depths (3, 4, 19, 5), heads (4, 8, 16, 32), mlp_ratio 2, layer scale 1e-5."""
    return _create_gcvit(
        'gcvit_base',
        pretrained=pretrained,
        depths=(3, 4, 19, 5),
        num_heads=(4, 8, 16, 32),
        embed_dim=128,
        mlp_ratio=2,
        layer_scale=1e-5,
        **kwargs,
    )