Update SwinTransformerV2Cr post-merge, update with grad checkpointing / grad matcher

* weight compat break, activate norm3 for final block of final stage (equivalent to pre-head norm, but while still in BLC shape)
* remove fold/unfold for TPU compat, add commented out roll code for TPU
* add option for end of stage norm in all stages
* allow weight_init to be selected between pytorch default inits and xavier / moco style vit variant
pull/1014/head
Ross Wightman 3 years ago
parent b049a5c5c6
commit fe457c1996

@ -12,6 +12,7 @@ This implementation is experimental and subject to change in manners that will b
GitHub link above. It needs further investigation as throughput vs mem tradeoff doesn't appear beneficial. GitHub link above. It needs further investigation as throughput vs mem tradeoff doesn't appear beneficial.
* num_heads per stage is not detailed for Huge and Giant model variants * num_heads per stage is not detailed for Huge and Giant model variants
* 'Giant' is 3B params in paper but ~2.6B here despite matching paper dim + block counts * 'Giant' is 3B params in paper but ~2.6B here despite matching paper dim + block counts
* experiments are ongoing wrt to 'main branch' norm layer use and weight init scheme
Noteworthy additions over official Swin v1: Noteworthy additions over official Swin v1:
* MLP relative position embedding is looking promising and adapts to different image/window sizes * MLP relative position embedding is looking promising and adapts to different image/window sizes
@ -37,7 +38,7 @@ import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_function from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg, overlay_external_default_cfg, named_apply from .helpers import build_model_with_cfg, named_apply
from .layers import DropPath, Mlp, to_2tuple, _assert from .layers import DropPath, Mlp, to_2tuple, _assert
from .registry import register_model from .registry import register_model
from .vision_transformer import checkpoint_filter_fn from .vision_transformer import checkpoint_filter_fn
@ -67,27 +68,29 @@ default_cfgs = {
'swin_v2_cr_tiny_384': _cfg( 'swin_v2_cr_tiny_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0), url="", input_size=(3, 384, 384), crop_pct=1.0),
'swin_v2_cr_tiny_224': _cfg( 'swin_v2_cr_tiny_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=1.0), url="", input_size=(3, 224, 224), crop_pct=0.9),
'swin_v2_cr_tiny_ns_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swin_v2_cr_small_384': _cfg( 'swin_v2_cr_small_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0), url="", input_size=(3, 384, 384), crop_pct=1.0),
'swin_v2_cr_small_224': _cfg( 'swin_v2_cr_small_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=1.0), url="", input_size=(3, 224, 224), crop_pct=0.9),
'swin_v2_cr_base_384': _cfg( 'swin_v2_cr_base_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0), url="", input_size=(3, 384, 384), crop_pct=1.0),
'swin_v2_cr_base_224': _cfg( 'swin_v2_cr_base_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=1.0), url="", input_size=(3, 224, 224), crop_pct=0.9),
'swin_v2_cr_large_384': _cfg( 'swin_v2_cr_large_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0), url="", input_size=(3, 384, 384), crop_pct=1.0),
'swin_v2_cr_large_224': _cfg( 'swin_v2_cr_large_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=1.0), url="", input_size=(3, 224, 224), crop_pct=0.9),
'swin_v2_cr_huge_384': _cfg( 'swin_v2_cr_huge_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0), url="", input_size=(3, 384, 384), crop_pct=1.0),
'swin_v2_cr_huge_224': _cfg( 'swin_v2_cr_huge_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=1.0), url="", input_size=(3, 224, 224), crop_pct=0.9),
'swin_v2_cr_giant_384': _cfg( 'swin_v2_cr_giant_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0), url="", input_size=(3, 384, 384), crop_pct=1.0),
'swin_v2_cr_giant_224': _cfg( 'swin_v2_cr_giant_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=1.0), url="", input_size=(3, 224, 224), crop_pct=0.9),
} }
@ -175,7 +178,7 @@ class WindowMultiHeadAttention(nn.Module):
hidden_features=meta_hidden_dim, hidden_features=meta_hidden_dim,
out_features=num_heads, out_features=num_heads,
act_layer=nn.ReLU, act_layer=nn.ReLU,
drop=0. # FIXME should we add stochasticity? drop=0.1 # FIXME should there be stochasticity, appears to 'overfit' without?
) )
self.register_parameter("tau", torch.nn.Parameter(torch.ones(num_heads))) self.register_parameter("tau", torch.nn.Parameter(torch.ones(num_heads)))
self._make_pair_wise_relative_positions() self._make_pair_wise_relative_positions()
@ -336,7 +339,8 @@ class SwinTransformerBlock(nn.Module):
self.norm2 = norm_layer(dim) self.norm2 = norm_layer(dim)
self.drop_path2 = DropPath(drop_prob=drop_path) if drop_path > 0.0 else nn.Identity() self.drop_path2 = DropPath(drop_prob=drop_path) if drop_path > 0.0 else nn.Identity()
# extra norm layer mentioned for Huge/Giant models in V2 paper (FIXME may be in wrong spot?) # Extra main branch norm layer mentioned for Huge/Giant models in V2 paper.
# Also being used as final network norm and optional stage ending norm while still in a C-last format.
self.norm3 = norm_layer(dim) if extra_norm else nn.Identity() self.norm3 = norm_layer(dim) if extra_norm else nn.Identity()
self._make_attention_mask() self._make_attention_mask()
@ -392,13 +396,16 @@ class SwinTransformerBlock(nn.Module):
x = x.view(B, H, W, C) x = x.view(B, H, W, C)
# cyclic shift # cyclic shift
if any(self.shift_size): sh, sw = self.shift_size
shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2)) do_shift: bool = any(self.shift_size)
else: if do_shift:
shifted_x = x # FIXME PyTorch XLA needs cat impl, roll not lowered
# x = torch.cat([x[:, sh:], x[:, :sh]], dim=1)
# x = torch.cat([x[:, :, sw:], x[:, :, :sw]], dim=2)
x = torch.roll(x, shifts=(-sh, -sw), dims=(1, 2))
# partition windows # partition windows
x_windows = window_partition(shifted_x, self.window_size) # num_windows * B, window_size, window_size, C x_windows = window_partition(x, self.window_size) # num_windows * B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1], C) x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1], C)
# W-MSA/SW-MSA # W-MSA/SW-MSA
@ -406,13 +413,14 @@ class SwinTransformerBlock(nn.Module):
# merge windows # merge windows
attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C) attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
shifted_x = window_reverse(attn_windows, self.window_size, self.feat_size) # B H' W' C x = window_reverse(attn_windows, self.window_size, self.feat_size) # B H' W' C
# reverse cyclic shift # reverse cyclic shift
if any(self.shift_size): if do_shift:
x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2)) # FIXME PyTorch XLA needs cat impl, roll not lowered
else: # x = torch.cat([x[:, -sh:], x[:, :-sh]], dim=1)
x = shifted_x # x = torch.cat([x[:, :, -sw:], x[:, :, :-sw]], dim=2)
x = torch.roll(x, shifts=(sh, sw), dims=(1, 2))
x = x.view(B, L, C) x = x.view(B, L, C)
return x return x
@ -429,7 +437,7 @@ class SwinTransformerBlock(nn.Module):
# NOTE post-norm branches (op -> norm -> drop) # NOTE post-norm branches (op -> norm -> drop)
x = x + self.drop_path1(self.norm1(self._shifted_window_attn(x))) x = x + self.drop_path1(self.norm1(self._shifted_window_attn(x)))
x = x + self.drop_path2(self.norm2(self.mlp(x))) x = x + self.drop_path2(self.norm2(self.mlp(x)))
x = self.norm3(x) # main-branch norm enabled for some blocks (every 6 for Huge/Giant) x = self.norm3(x) # main-branch norm enabled for some blocks / stages (every 6 for Huge/Giant)
return x return x
@ -452,8 +460,10 @@ class PatchMerging(nn.Module):
Returns: Returns:
output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2] output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2]
""" """
x = bchw_to_bhwc(x).unfold(dimension=1, size=2, step=2).unfold(dimension=2, size=2, step=2) B, C, H, W = x.shape
x = x.permute(0, 1, 2, 5, 4, 3).flatten(3) # permute maintains compat with ch order in official swin impl # unfold + BCHW -> BHWC together
# ordering, 5, 3, 1 instead of 3, 5, 1 maintains compat with original swin v1 merge
x = x.reshape(B, C, H // 2, 2, W // 2, 2).permute(0, 2, 4, 5, 3, 1).flatten(3)
x = self.norm(x) x = self.norm(x)
x = bhwc_to_bchw(self.reduction(x)) x = bhwc_to_bchw(self.reduction(x))
return x return x
@ -497,8 +507,8 @@ class SwinTransformerStage(nn.Module):
drop_attn (float): Dropout rate of attention map drop_attn (float): Dropout rate of attention map
drop_path (float): Dropout in main path drop_path (float): Dropout in main path
norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm
grad_checkpointing (bool): If true checkpointing is utilized
extra_norm_period (int): Insert extra norm layer on main branch every N (period) blocks extra_norm_period (int): Insert extra norm layer on main branch every N (period) blocks
extra_norm_stage (bool): End each stage with an extra norm layer in main branch
sequential_attn (bool): If true sequential self-attention is performed sequential_attn (bool): If true sequential self-attention is performed
""" """
@ -515,17 +525,23 @@ class SwinTransformerStage(nn.Module):
drop_attn: float = 0.0, drop_attn: float = 0.0,
drop_path: Union[List[float], float] = 0.0, drop_path: Union[List[float], float] = 0.0,
norm_layer: Type[nn.Module] = nn.LayerNorm, norm_layer: Type[nn.Module] = nn.LayerNorm,
grad_checkpointing: bool = False,
extra_norm_period: int = 0, extra_norm_period: int = 0,
extra_norm_stage: bool = False,
sequential_attn: bool = False, sequential_attn: bool = False,
) -> None: ) -> None:
super(SwinTransformerStage, self).__init__() super(SwinTransformerStage, self).__init__()
self.downscale: bool = downscale self.downscale: bool = downscale
self.grad_checkpointing: bool = grad_checkpointing self.grad_checkpointing: bool = False
self.feat_size: Tuple[int, int] = (feat_size[0] // 2, feat_size[1] // 2) if downscale else feat_size self.feat_size: Tuple[int, int] = (feat_size[0] // 2, feat_size[1] // 2) if downscale else feat_size
self.downsample = PatchMerging(embed_dim, norm_layer=norm_layer) if downscale else nn.Identity() self.downsample = PatchMerging(embed_dim, norm_layer=norm_layer) if downscale else nn.Identity()
def _extra_norm(index):
i = index + 1
if extra_norm_period and i % extra_norm_period == 0:
return True
return i == depth if extra_norm_stage else False
embed_dim = embed_dim * 2 if downscale else embed_dim embed_dim = embed_dim * 2 if downscale else embed_dim
self.blocks = nn.Sequential(*[ self.blocks = nn.Sequential(*[
SwinTransformerBlock( SwinTransformerBlock(
@ -538,7 +554,7 @@ class SwinTransformerStage(nn.Module):
drop=drop, drop=drop,
drop_attn=drop_attn, drop_attn=drop_attn,
drop_path=drop_path[index] if isinstance(drop_path, list) else drop_path, drop_path=drop_path[index] if isinstance(drop_path, list) else drop_path,
extra_norm=not (index + 1) % extra_norm_period if extra_norm_period else False, extra_norm=_extra_norm(index),
sequential_attn=sequential_attn, sequential_attn=sequential_attn,
norm_layer=norm_layer, norm_layer=norm_layer,
) )
@ -600,9 +616,9 @@ class SwinTransformerV2Cr(nn.Module):
attn_drop_rate (float): Dropout rate of attention map. Default: 0.0 attn_drop_rate (float): Dropout rate of attention map. Default: 0.0
drop_path_rate (float): Stochastic depth rate. Default: 0.0 drop_path_rate (float): Stochastic depth rate. Default: 0.0
norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm
grad_checkpointing (bool): If true checkpointing is utilized. Default: False extra_norm_period (int): Insert extra norm layer on main branch every N (period) blocks in stage
extra_norm_stage (bool): End each stage with an extra norm layer in main branch
sequential_attn (bool): If true sequential self-attention is performed. Default: False sequential_attn (bool): If true sequential self-attention is performed. Default: False
use_deformable (bool): If true deformable block is used. Default: False
""" """
def __init__( def __init__(
@ -621,10 +637,11 @@ class SwinTransformerV2Cr(nn.Module):
attn_drop_rate: float = 0.0, attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.0, drop_path_rate: float = 0.0,
norm_layer: Type[nn.Module] = nn.LayerNorm, norm_layer: Type[nn.Module] = nn.LayerNorm,
grad_checkpointing: bool = False,
extra_norm_period: int = 0, extra_norm_period: int = 0,
extra_norm_stage: bool = False,
sequential_attn: bool = False, sequential_attn: bool = False,
global_pool: str = 'avg', global_pool: str = 'avg',
weight_init='skip',
**kwargs: Any **kwargs: Any
) -> None: ) -> None:
super(SwinTransformerV2Cr, self).__init__() super(SwinTransformerV2Cr, self).__init__()
@ -638,7 +655,7 @@ class SwinTransformerV2Cr(nn.Module):
self.window_size: int = window_size self.window_size: int = window_size
self.num_features: int = int(embed_dim * 2 ** (len(depths) - 1)) self.num_features: int = int(embed_dim * 2 ** (len(depths) - 1))
self.patch_embed: nn.Module = PatchEmbed( self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, img_size=img_size, patch_size=patch_size, in_chans=in_chans,
embed_dim=embed_dim, norm_layer=norm_layer) embed_dim=embed_dim, norm_layer=norm_layer)
patch_grid_size: Tuple[int, int] = self.patch_embed.grid_size patch_grid_size: Tuple[int, int] = self.patch_embed.grid_size
@ -659,8 +676,8 @@ class SwinTransformerV2Cr(nn.Module):
drop=drop_rate, drop=drop_rate,
drop_attn=attn_drop_rate, drop_attn=attn_drop_rate,
drop_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])], drop_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])],
grad_checkpointing=grad_checkpointing,
extra_norm_period=extra_norm_period, extra_norm_period=extra_norm_period,
extra_norm_stage=extra_norm_stage or (index + 1) == len(depths), # last stage ends w/ norm
sequential_attn=sequential_attn, sequential_attn=sequential_attn,
norm_layer=norm_layer, norm_layer=norm_layer,
) )
@ -668,12 +685,12 @@ class SwinTransformerV2Cr(nn.Module):
self.stages = nn.Sequential(*stages) self.stages = nn.Sequential(*stages)
self.global_pool: str = global_pool self.global_pool: str = global_pool
self.head: nn.Module = nn.Linear( self.head = nn.Linear(self.num_features, num_classes) if num_classes else nn.Identity()
in_features=self.num_features, out_features=num_classes) if num_classes else nn.Identity()
# FIXME weight init TBD, PyTorch default init appears to be working well, # current weight init skips custom init and uses pytorch layer defaults, seems to work well
# but differs from usual ViT or Swin init. # FIXME more experiments needed
# named_apply(init_weights, self) if weight_init != 'skip':
named_apply(init_weights, self)
def update_input_size( def update_input_size(
self, self,
@ -704,13 +721,28 @@ class SwinTransformerV2Cr(nn.Module):
new_img_size=(new_patch_grid_size[0] // stage_scale, new_patch_grid_size[1] // stage_scale), new_img_size=(new_patch_grid_size[0] // stage_scale, new_patch_grid_size[1] // stage_scale),
) )
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^patch_embed', # stem and embed
blocks=r'^stages\.(\d+)' if coarse else [
(r'^stages\.(\d+).downsample', (0,)),
(r'^stages\.(\d+)\.\w+\.(\d+)', None),
]
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
for s in self.stages:
s.grad_checkpointing = enable
@torch.jit.ignore()
def get_classifier(self) -> nn.Module: def get_classifier(self) -> nn.Module:
"""Method returns the classification head of the model. """Method returns the classification head of the model.
Returns: Returns:
head (nn.Module): Current classification head head (nn.Module): Current classification head
""" """
head: nn.Module = self.head return self.head
return head
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None: def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
"""Method results the classification head """Method results the classification head
@ -722,8 +754,7 @@ class SwinTransformerV2Cr(nn.Module):
self.num_classes: int = num_classes self.num_classes: int = num_classes
if global_pool is not None: if global_pool is not None:
self.global_pool = global_pool self.global_pool = global_pool
self.head: nn.Module = nn.Linear( self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
in_features=self.num_features, out_features=num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x: torch.Tensor) -> torch.Tensor: def forward_features(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x) x = self.patch_embed(x)
@ -742,41 +773,28 @@ class SwinTransformerV2Cr(nn.Module):
def init_weights(module: nn.Module, name: str = ''): def init_weights(module: nn.Module, name: str = ''):
# FIXME WIP # FIXME WIP determining if there's a better weight init
if isinstance(module, nn.Linear): if isinstance(module, nn.Linear):
if 'qkv' in name: if 'qkv' in name:
# treat the weights of Q, K, V separately # treat the weights of Q, K, V separately
val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1])) val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1]))
nn.init.uniform_(module.weight, -val, val) nn.init.uniform_(module.weight, -val, val)
elif 'head' in name:
nn.init.zeros_(module.weight)
else: else:
nn.init.xavier_uniform_(module.weight) nn.init.xavier_uniform_(module.weight)
if module.bias is not None: if module.bias is not None:
nn.init.zeros_(module.bias) nn.init.zeros_(module.bias)
def _create_swin_transformer_v2_cr(variant, pretrained=False, default_cfg=None, **kwargs): def _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs):
if default_cfg is None:
default_cfg = deepcopy(default_cfgs[variant])
overlay_external_default_cfg(default_cfg, kwargs)
default_num_classes = default_cfg['num_classes']
default_img_size = default_cfg['input_size'][-2:]
num_classes = kwargs.pop('num_classes', default_num_classes)
img_size = kwargs.pop('img_size', default_img_size)
if kwargs.get('features_only', None): if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.') raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg( model = build_model_with_cfg(
SwinTransformerV2Cr, SwinTransformerV2Cr, variant, pretrained,
variant,
pretrained,
default_cfg=default_cfg,
img_size=img_size,
num_classes=num_classes,
pretrained_filter_fn=checkpoint_filter_fn, pretrained_filter_fn=checkpoint_filter_fn,
**kwargs **kwargs
) )
return model return model
@ -804,6 +822,21 @@ def swin_v2_cr_tiny_224(pretrained=False, **kwargs):
return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_224', pretrained=pretrained, **model_kwargs) return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_224', pretrained=pretrained, **model_kwargs)
@register_model
def swin_v2_cr_tiny_ns_224(pretrained=False, **kwargs):
"""Swin-T V2 CR @ 224x224, trained ImageNet-1k w/ extra stage norms.
** Experimental, may make default if results are improved. **
"""
model_kwargs = dict(
embed_dim=96,
depths=(2, 2, 6, 2),
num_heads=(3, 6, 12, 24),
extra_norm_stage=True,
**kwargs
)
return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_ns_224', pretrained=pretrained, **model_kwargs)
@register_model @register_model
def swin_v2_cr_small_384(pretrained=False, **kwargs): def swin_v2_cr_small_384(pretrained=False, **kwargs):
"""Swin-S V2 CR @ 384x384, trained ImageNet-1k""" """Swin-S V2 CR @ 384x384, trained ImageNet-1k"""

Loading…
Cancel
Save