""" Pyramid Vision Transformer v2 @misc{wang2021pvtv2, title={PVTv2: Improved Baselines with Pyramid Vision Transformer}, author={Wenhai Wang and Enze Xie and Xiang Li and Deng-Ping Fan and Kaitao Song and Ding Liang and Tong Lu and Ping Luo and Ling Shao}, year={2021}, eprint={2106.13797}, archivePrefix={arXiv}, primaryClass={cs.CV} } Based on Apache 2.0 licensed code at https://github.com/whai362/PVT Modifications and timm support by / Copyright 2022, Ross Wightman """ import math from functools import partial from typing import Tuple, List, Callable, Union import torch import torch.nn as nn import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import DropPath, to_2tuple, to_ntuple, trunc_normal_ from ._builder import build_model_with_cfg from ._registry import register_model __all__ = ['PyramidVisionTransformerV2'] def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.9, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'first_conv': 'patch_embed.proj', 'classifier': 'head', 'fixed_input_size': False, **kwargs } default_cfgs = { 'pvt_v2_b0': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b0.pth'), 'pvt_v2_b1': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b1.pth'), 'pvt_v2_b2': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b2.pth'), 'pvt_v2_b3': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b3.pth'), 'pvt_v2_b4': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b4.pth'), 'pvt_v2_b5': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b5.pth'), 'pvt_v2_b2_li': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b2_li.pth') } class MlpWithDepthwiseConv(nn.Module): def __init__( self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., extra_relu=False): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.relu = nn.ReLU() if extra_relu else nn.Identity() self.dwconv = nn.Conv2d(hidden_features, hidden_features, 3, 1, 1, bias=True, groups=hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x, feat_size: List[int]): x = self.fc1(x) B, N, C = x.shape x = x.transpose(1, 2).view(B, C, feat_size[0], feat_size[1]) x = self.relu(x) x = self.dwconv(x) x = x.flatten(2).transpose(1, 2) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class Attention(nn.Module): def __init__( self, dim, num_heads=8, sr_ratio=1, linear_attn=False, qkv_bias=True, attn_drop=0., proj_drop=0. ): super().__init__() assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
class Attention(nn.Module):
    def __init__(
            self,
            dim,
            num_heads=8,
            sr_ratio=1,
            linear_attn=False,
            qkv_bias=True,
            attn_drop=0.,
            proj_drop=0.
    ):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        if not linear_attn:
            self.pool = None
            if sr_ratio > 1:
                self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
                self.norm = nn.LayerNorm(dim)
            else:
                self.sr = None
                self.norm = None
            self.act = None
        else:
            self.pool = nn.AdaptiveAvgPool2d(7)
            self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1)
            self.norm = nn.LayerNorm(dim)
            self.act = nn.GELU()

    def forward(self, x, feat_size: List[int]):
        B, N, C = x.shape
        H, W = feat_size
        q = self.q(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)

        if self.pool is not None:
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            x_ = self.act(x_)
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        else:
            if self.sr is not None:
                x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
                x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
                x_ = self.norm(x_)
                kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
            else:
                kv = self.kv(x).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
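
# Shape walk-through for the spatial-reduction attention above (illustrative
# only): with sr_ratio=8 on a 56x56 map, keys/values come from an 8x strided
# conv, so the attention matrix is 3136 x 49 rather than 3136 x 3136. The
# linear_attn variant instead pools k/v to a fixed 7x7 regardless of input size.
#
#   attn = Attention(dim=64, num_heads=1, sr_ratio=8)
#   x = torch.randn(2, 56 * 56, 64)            # 3136 query tokens
#   out = attn(x, feat_size=[56, 56])          # k/v reduced to 7*7 = 49 tokens
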
class Block(nn.Module):

    def __init__(
            self, dim, num_heads, mlp_ratio=4., sr_ratio=1, linear_attn=False, qkv_bias=False,
            drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            sr_ratio=sr_ratio,
            linear_attn=linear_attn,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = MlpWithDepthwiseConv(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
            extra_relu=linear_attn,
        )

    def forward(self, x, feat_size: List[int]):
        x = x + self.drop_path(self.attn(self.norm1(x), feat_size))
        x = x + self.drop_path(self.mlp(self.norm2(x), feat_size))
        return x


class OverlapPatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        assert max(patch_size) > stride, "Set larger patch_size than stride"
        self.patch_size = patch_size
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=stride,
            padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.proj(x)
        feat_size = x.shape[-2:]
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, feat_size


class PyramidVisionTransformerStage(nn.Module):
    def __init__(
            self,
            dim: int,
            dim_out: int,
            depth: int,
            downsample: bool = True,
            num_heads: int = 8,
            sr_ratio: int = 1,
            linear_attn: bool = False,
            mlp_ratio: float = 4.0,
            qkv_bias: bool = True,
            drop: float = 0.,
            attn_drop: float = 0.,
            drop_path: Union[List[float], float] = 0.0,
            norm_layer: Callable = nn.LayerNorm,
    ):
        super().__init__()
        self.grad_checkpointing = False

        if downsample:
            self.downsample = OverlapPatchEmbed(
                patch_size=3,
                stride=2,
                in_chans=dim,
                embed_dim=dim_out)
        else:
            assert dim == dim_out
            self.downsample = None

        self.blocks = nn.ModuleList([Block(
            dim=dim_out,
            num_heads=num_heads,
            sr_ratio=sr_ratio,
            linear_attn=linear_attn,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            drop=drop,
            attn_drop=attn_drop,
            drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
            norm_layer=norm_layer,
        ) for i in range(depth)])

        self.norm = norm_layer(dim_out)

    def forward(self, x, feat_size: List[int]) -> Tuple[torch.Tensor, List[int]]:
        if self.downsample is not None:
            x, feat_size = self.downsample(x)
        for blk in self.blocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint.checkpoint(blk, x, feat_size)
            else:
                x = blk(x, feat_size)
        x = self.norm(x)
        x = x.reshape(x.shape[0], feat_size[0], feat_size[1], -1).permute(0, 3, 1, 2).contiguous()
        return x, feat_size
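
# Stage interface note (sketch): a stage with downsample=True expects the NCHW
# map produced by the previous stage and roughly halves its resolution via
# OverlapPatchEmbed (3x3 conv, stride 2); the first stage consumes the token
# sequence from patch_embed directly. Every stage returns an NCHW map plus the
# updated [H, W], e.g.:
#
#   stage = PyramidVisionTransformerStage(
#       dim=64, dim_out=128, depth=2, downsample=True, num_heads=2, sr_ratio=4)
#   x, feat_size = stage(torch.randn(2, 64, 56, 56), feat_size=[56, 56])
#   # -> x of shape (2, 128, 28, 28), feat_size [28, 28]
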
class PyramidVisionTransformerV2(nn.Module):
    def __init__(
            self,
            img_size=None,
            in_chans=3,
            num_classes=1000,
            global_pool='avg',
            depths=(3, 4, 6, 3),
            embed_dims=(64, 128, 256, 512),
            num_heads=(1, 2, 4, 8),
            sr_ratios=(8, 4, 2, 1),
            mlp_ratios=(8., 8., 4., 4.),
            qkv_bias=True,
            linear=False,
            drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.,
            norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.num_classes = num_classes
        assert global_pool in ('avg', '')
        self.global_pool = global_pool
        self.depths = depths
        num_stages = len(depths)
        mlp_ratios = to_ntuple(num_stages)(mlp_ratios)
        num_heads = to_ntuple(num_stages)(num_heads)
        sr_ratios = to_ntuple(num_stages)(sr_ratios)
        assert len(embed_dims) == num_stages

        self.patch_embed = OverlapPatchEmbed(
            patch_size=7,
            stride=4,
            in_chans=in_chans,
            embed_dim=embed_dims[0])

        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
        prev_dim = embed_dims[0]
        self.stages = nn.ModuleList()
        for i in range(num_stages):
            self.stages.append(PyramidVisionTransformerStage(
                dim=prev_dim,
                dim_out=embed_dims[i],
                depth=depths[i],
                downsample=i > 0,
                num_heads=num_heads[i],
                sr_ratio=sr_ratios[i],
                mlp_ratio=mlp_ratios[i],
                linear_attn=linear,
                qkv_bias=qkv_bias,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
            ))
            prev_dim = embed_dims[i]

        # classification head
        self.num_features = embed_dims[-1]
        self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def freeze_patch_emb(self):
        # Freeze the stem; requires_grad must be set per-parameter, setting it
        # on the module itself has no effect.
        for p in self.patch_embed.parameters():
            p.requires_grad = False

    @torch.jit.ignore
    def no_weight_decay(self):
        return {}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        matcher = dict(
            stem=r'^patch_embed',  # stem and embed
            blocks=r'^stages\.(\d+)'
        )
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        for s in self.stages:
            s.grad_checkpointing = enable

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=None):
        self.num_classes = num_classes
        if global_pool is not None:
            assert global_pool in ('avg', '')
            self.global_pool = global_pool
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x, feat_size = self.patch_embed(x)
        for stage in self.stages:
            x, feat_size = stage(x, feat_size=feat_size)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        if self.global_pool:
            x = x.mean(dim=(-1, -2))
        return x if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x


def _checkpoint_filter_fn(state_dict, model):
    """ Remap original checkpoints -> timm """
    if 'patch_embed.proj.weight' in state_dict:
        return state_dict  # non-original checkpoint, no remapping needed

    out_dict = {}
    import re
    for k, v in state_dict.items():
        if k.startswith('patch_embed'):
            k = k.replace('patch_embed1', 'patch_embed')
            k = k.replace('patch_embed2', 'stages.1.downsample')
            k = k.replace('patch_embed3', 'stages.2.downsample')
            k = k.replace('patch_embed4', 'stages.3.downsample')
        k = k.replace('dwconv.dwconv', 'dwconv')
        k = re.sub(r'block(\d+).(\d+)', lambda x: f'stages.{int(x.group(1)) - 1}.blocks.{x.group(2)}', k)
        k = re.sub(r'^norm(\d+)', lambda x: f'stages.{int(x.group(1)) - 1}.norm', k)
        out_dict[k] = v
    return out_dict


def _create_pvt2(variant, pretrained=False, **kwargs):
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')
    model = build_model_with_cfg(
        PyramidVisionTransformerV2,
        variant,
        pretrained,
        pretrained_filter_fn=_checkpoint_filter_fn,
        **kwargs
    )
    return model
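
# Key remapping example (illustrative): _checkpoint_filter_fn, passed above as
# pretrained_filter_fn, translates original PVT checkpoint keys into this
# module's layout, e.g.
#   'patch_embed1.proj.weight' -> 'patch_embed.proj.weight'
#   'patch_embed2.proj.weight' -> 'stages.1.downsample.proj.weight'
#   'block3.0.attn.q.weight'   -> 'stages.2.blocks.0.attn.q.weight'
#   'norm4.weight'             -> 'stages.3.norm.weight'
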
@register_model
def pvt_v2_b0(pretrained=False, **kwargs):
    model_kwargs = dict(
        depths=(2, 2, 2, 2), embed_dims=(32, 64, 160, 256), num_heads=(1, 2, 5, 8),
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return _create_pvt2('pvt_v2_b0', pretrained=pretrained, **model_kwargs)


@register_model
def pvt_v2_b1(pretrained=False, **kwargs):
    model_kwargs = dict(
        depths=(2, 2, 2, 2), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8),
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return _create_pvt2('pvt_v2_b1', pretrained=pretrained, **model_kwargs)


@register_model
def pvt_v2_b2(pretrained=False, **kwargs):
    model_kwargs = dict(
        depths=(3, 4, 6, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8),
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return _create_pvt2('pvt_v2_b2', pretrained=pretrained, **model_kwargs)


@register_model
def pvt_v2_b3(pretrained=False, **kwargs):
    model_kwargs = dict(
        depths=(3, 4, 18, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8),
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return _create_pvt2('pvt_v2_b3', pretrained=pretrained, **model_kwargs)


@register_model
def pvt_v2_b4(pretrained=False, **kwargs):
    model_kwargs = dict(
        depths=(3, 8, 27, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8),
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return _create_pvt2('pvt_v2_b4', pretrained=pretrained, **model_kwargs)


@register_model
def pvt_v2_b5(pretrained=False, **kwargs):
    model_kwargs = dict(
        depths=(3, 6, 40, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8),
        mlp_ratios=(4, 4, 4, 4), norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return _create_pvt2('pvt_v2_b5', pretrained=pretrained, **model_kwargs)


@register_model
def pvt_v2_b2_li(pretrained=False, **kwargs):
    model_kwargs = dict(
        depths=(3, 4, 6, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8),
        norm_layer=partial(nn.LayerNorm, eps=1e-6), linear=True, **kwargs)
    return _create_pvt2('pvt_v2_b2_li', pretrained=pretrained, **model_kwargs)
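
# Example usage (sketch, assuming this module is installed as part of timm so
# the variants above are registered with the model registry):
#
#   import torch
#   import timm
#   model = timm.create_model('pvt_v2_b2', pretrained=True).eval()
#   with torch.no_grad():
#       logits = model(torch.randn(1, 3, 224, 224))   # -> shape (1, 1000)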