diff --git a/timm/models/__init__.py b/timm/models/__init__.py
index 195e451b..9cb577f6 100644
--- a/timm/models/__init__.py
+++ b/timm/models/__init__.py
@@ -28,6 +28,7 @@ from .mobilenetv3 import *
 from .mobilevit import *
 from .nasnet import *
 from .nest import *
+from .nextvit import *
 from .nfnet import *
 from .pit import *
 from .pnasnet import *
diff --git a/timm/models/nextvit.py b/timm/models/nextvit.py
new file mode 100644
index 00000000..50ff2ce1
--- /dev/null
+++ b/timm/models/nextvit.py
@@ -0,0 +1,390 @@
+# Copyright (c) ByteDance Inc. All rights reserved.
+from functools import partial
+
+import torch
+import torch.utils.checkpoint as checkpoint
+from einops import rearrange
+from timm.models.layers import DropPath, trunc_normal_
+from timm.models.registry import register_model
+from torch import nn
+from utils import merge_pre_bn
+
+NORM_EPS = 1e-5
+
+
+class ConvBNReLU(nn.Module):
+    def __init__(
+            self,
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            groups=1):
+        super(ConvBNReLU, self).__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
+                              padding=1, groups=groups, bias=False)
+        self.norm = nn.BatchNorm2d(out_channels, eps=NORM_EPS)
+        self.act = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.norm(x)
+        x = self.act(x)
+        return x
+
+
+def _make_divisible(v, divisor, min_value=None):
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that round down does not go down by more than 10%.
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
+
+
+class PatchEmbed(nn.Module):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 stride=1):
+        super(PatchEmbed, self).__init__()
+        norm_layer = partial(nn.BatchNorm2d, eps=NORM_EPS)
+        if stride == 2:
+            self.avgpool = nn.AvgPool2d((2, 2), stride=2, ceil_mode=True, count_include_pad=False)
+            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
+            self.norm = norm_layer(out_channels)
+        elif in_channels != out_channels:
+            self.avgpool = nn.Identity()
+            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
+            self.norm = norm_layer(out_channels)
+        else:
+            self.avgpool = nn.Identity()
+            self.conv = nn.Identity()
+            self.norm = nn.Identity()
+
+    def forward(self, x):
+        return self.norm(self.conv(self.avgpool(x)))
+
+
+class MHCA(nn.Module):
+    """
+    Multi-Head Convolutional Attention
+    """
+    def __init__(self, out_channels, head_dim):
+        super(MHCA, self).__init__()
+        norm_layer = partial(nn.BatchNorm2d, eps=NORM_EPS)
+        self.group_conv3x3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1,
+                                       padding=1, groups=out_channels // head_dim, bias=False)
+        self.norm = norm_layer(out_channels)
+        self.act = nn.ReLU(inplace=True)
+        self.projection = nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False)
+
+    def forward(self, x):
+        out = self.group_conv3x3(x)
+        out = self.norm(out)
+        out = self.act(out)
+        out = self.projection(out)
+        return out
+
+
+class Mlp(nn.Module):
+    def __init__(self, in_features, out_features=None, mlp_ratio=None, drop=0., bias=True):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_dim = _make_divisible(in_features * mlp_ratio, 32)
+        self.conv1 = nn.Conv2d(in_features, hidden_dim, kernel_size=1, bias=bias)
+        self.act = nn.ReLU(inplace=True)
+        self.conv2 = nn.Conv2d(hidden_dim, out_features, kernel_size=1, bias=bias)
+        self.drop = nn.Dropout(drop)
+
+    def merge_bn(self, pre_norm):
+        merge_pre_bn(self.conv1, pre_norm)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.conv2(x)
+        x = self.drop(x)
+        return x
+
+
+class NCB(nn.Module):
+    """
+    Next Convolution Block
+    """
+    def __init__(self, in_channels, out_channels, stride=1, path_dropout=0,
+                 drop=0, head_dim=32, mlp_ratio=3):
+        super(NCB, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        norm_layer = partial(nn.BatchNorm2d, eps=NORM_EPS)
+        assert out_channels % head_dim == 0
+
+        self.patch_embed = PatchEmbed(in_channels, out_channels, stride)
+        self.mhca = MHCA(out_channels, head_dim)
+        self.attention_path_dropout = DropPath(path_dropout)
+
+        self.norm = norm_layer(out_channels)
+        self.mlp = Mlp(out_channels, mlp_ratio=mlp_ratio, drop=drop, bias=True)
+        self.mlp_path_dropout = DropPath(path_dropout)
+        self.is_bn_merged = False
+
+    def merge_bn(self):
+        if not self.is_bn_merged:
+            self.mlp.merge_bn(self.norm)
+            self.is_bn_merged = True
+
+    def forward(self, x):
+        x = self.patch_embed(x)
+        x = x + self.attention_path_dropout(self.mhca(x))
+        if not torch.onnx.is_in_onnx_export() and not self.is_bn_merged:
+            out = self.norm(x)
+        else:
+            out = x
+        x = x + self.mlp_path_dropout(self.mlp(out))
+        return x
+
+
+class E_MHSA(nn.Module):
+    """
+    Efficient Multi-Head Self Attention
+    """
+    def __init__(self, dim, out_dim=None, head_dim=32, qkv_bias=True, qk_scale=None,
+                 attn_drop=0, proj_drop=0., sr_ratio=1):
+        super().__init__()
+        self.dim = dim
+        self.out_dim = out_dim if out_dim is not None else dim
+        self.num_heads = self.dim // head_dim
+        self.scale = qk_scale or head_dim ** -0.5
+        self.q = nn.Linear(dim, self.dim, bias=qkv_bias)
+        self.k = nn.Linear(dim, self.dim, bias=qkv_bias)
+        self.v = nn.Linear(dim, self.dim, bias=qkv_bias)
+        self.proj = nn.Linear(self.dim, self.out_dim)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        self.sr_ratio = sr_ratio
+        self.N_ratio = sr_ratio ** 2
+        if sr_ratio > 1:
+            self.sr = nn.AvgPool1d(kernel_size=self.N_ratio, stride=self.N_ratio)
+            self.norm = nn.BatchNorm1d(dim, eps=NORM_EPS)
+        self.is_bn_merged = False
+
+    def merge_bn(self, pre_bn):
+        merge_pre_bn(self.q, pre_bn)
+        if self.sr_ratio > 1:
+            merge_pre_bn(self.k, pre_bn, self.norm)
+            merge_pre_bn(self.v, pre_bn, self.norm)
+        else:
+            merge_pre_bn(self.k, pre_bn)
+            merge_pre_bn(self.v, pre_bn)
+        self.is_bn_merged = True
+
+    def forward(self, x):
+        B, N, C = x.shape
+        q = self.q(x)
+        q = q.reshape(B, N, self.num_heads, int(C // self.num_heads)).permute(0, 2, 1, 3)
+
+        if self.sr_ratio > 1:
+            x_ = x.transpose(1, 2)
+            x_ = self.sr(x_)
+            if not torch.onnx.is_in_onnx_export() and not self.is_bn_merged:
+                x_ = self.norm(x_)
+            x_ = x_.transpose(1, 2)
+            k = self.k(x_)
+            k = k.reshape(B, -1, self.num_heads, int(C // self.num_heads)).permute(0, 2, 3, 1)
+            v = self.v(x_)
+            v = v.reshape(B, -1, self.num_heads, int(C // self.num_heads)).permute(0, 2, 1, 3)
+        else:
+            k = self.k(x)
+            k = k.reshape(B, -1, self.num_heads, int(C // self.num_heads)).permute(0, 2, 3, 1)
+            v = self.v(x)
+            v = v.reshape(B, -1, self.num_heads, int(C // self.num_heads)).permute(0, 2, 1, 3)
+        attn = (q @ k) * self.scale
+
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class NTB(nn.Module):
+    """
+    Next Transformer Block
+    """
+    def __init__(
+            self, in_channels, out_channels, path_dropout, stride=1, sr_ratio=1,
+            mlp_ratio=2, head_dim=32, mix_block_ratio=0.75, attn_drop=0, drop=0,
+    ):
+        super(NTB, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.mix_block_ratio = mix_block_ratio
+        norm_func = partial(nn.BatchNorm2d, eps=NORM_EPS)
+
+        self.mhsa_out_channels = _make_divisible(int(out_channels * mix_block_ratio), 32)
+        self.mhca_out_channels = out_channels - self.mhsa_out_channels
+
+        self.patch_embed = PatchEmbed(in_channels, self.mhsa_out_channels, stride)
+        self.norm1 = norm_func(self.mhsa_out_channels)
+        self.e_mhsa = E_MHSA(self.mhsa_out_channels, head_dim=head_dim, sr_ratio=sr_ratio,
+                             attn_drop=attn_drop, proj_drop=drop)
+        self.mhsa_path_dropout = DropPath(path_dropout * mix_block_ratio)
+
+        self.projection = PatchEmbed(self.mhsa_out_channels, self.mhca_out_channels, stride=1)
+        self.mhca = MHCA(self.mhca_out_channels, head_dim=head_dim)
+        self.mhca_path_dropout = DropPath(path_dropout * (1 - mix_block_ratio))
+
+        self.norm2 = norm_func(out_channels)
+        self.mlp = Mlp(out_channels, mlp_ratio=mlp_ratio, drop=drop)
+        self.mlp_path_dropout = DropPath(path_dropout)
+
+        self.is_bn_merged = False
+
+    def merge_bn(self):
+        if not self.is_bn_merged:
+            self.e_mhsa.merge_bn(self.norm1)
+            self.mlp.merge_bn(self.norm2)
+            self.is_bn_merged = True
+
+    def forward(self, x):
+        x = self.patch_embed(x)
+        B, C, H, W = x.shape
+        if not torch.onnx.is_in_onnx_export() and not self.is_bn_merged:
+            out = self.norm1(x)
+        else:
+            out = x
+        out = rearrange(out, "b c h w -> b (h w) c")  # b n c
+        out = self.mhsa_path_dropout(self.e_mhsa(out))
+        x = x + rearrange(out, "b (h w) c -> b c h w", h=H)
+
+        out = self.projection(x)
+        out = out + self.mhca_path_dropout(self.mhca(out))
+        x = torch.cat([x, out], dim=1)
+
+        if not torch.onnx.is_in_onnx_export() and not self.is_bn_merged:
+            out = self.norm2(x)
+        else:
+            out = x
+        x = x + self.mlp_path_dropout(self.mlp(out))
+        return x
+
+
+class NextViT(nn.Module):
+    def __init__(self, stem_chs, depths, path_dropout, attn_drop=0, drop=0, num_classes=1000,
+                 strides=[1, 2, 2, 2], sr_ratios=[8, 4, 2, 1], head_dim=32, mix_block_ratio=0.75,
+                 use_checkpoint=False):
+        super(NextViT, self).__init__()
+        self.use_checkpoint = use_checkpoint
+
+        self.stage_out_channels = [[96] * (depths[0]),
+                                   [192] * (depths[1] - 1) + [256],
+                                   [384, 384, 384, 384, 512] * (depths[2] // 5),
+                                   [768] * (depths[3] - 1) + [1024]]
+
+        # Next Hybrid Strategy
+        self.stage_block_types = [[NCB] * depths[0],
+                                  [NCB] * (depths[1] - 1) + [NTB],
+                                  [NCB, NCB, NCB, NCB, NTB] * (depths[2] // 5),
+                                  [NCB] * (depths[3] - 1) + [NTB]]
+
+        self.stem = nn.Sequential(
+            ConvBNReLU(3, stem_chs[0], kernel_size=3, stride=2),
+            ConvBNReLU(stem_chs[0], stem_chs[1], kernel_size=3, stride=1),
+            ConvBNReLU(stem_chs[1], stem_chs[2], kernel_size=3, stride=1),
+            ConvBNReLU(stem_chs[2], stem_chs[2], kernel_size=3, stride=2),
+        )
+        input_channel = stem_chs[-1]
+        features = []
+        idx = 0
+        dpr = [x.item() for x in torch.linspace(0, path_dropout, sum(depths))]  # stochastic depth decay rule
+        for stage_id in range(len(depths)):
+            numrepeat = depths[stage_id]
+            output_channels = self.stage_out_channels[stage_id]
+            block_types = self.stage_block_types[stage_id]
+            for block_id in range(numrepeat):
+                if strides[stage_id] == 2 and block_id == 0:
+                    stride = 2
+                else:
+                    stride = 1
+                output_channel = output_channels[block_id]
+                block_type = block_types[block_id]
+                if block_type is NCB:
+                    layer = NCB(input_channel, output_channel, stride=stride, path_dropout=dpr[idx + block_id],
+                                drop=drop, head_dim=head_dim)
+                    features.append(layer)
+                elif block_type is NTB:
+                    layer = NTB(input_channel, output_channel, path_dropout=dpr[idx + block_id], stride=stride,
+                                sr_ratio=sr_ratios[stage_id], head_dim=head_dim, mix_block_ratio=mix_block_ratio,
+                                attn_drop=attn_drop, drop=drop)
+                    features.append(layer)
+                input_channel = output_channel
+            idx += numrepeat
+        self.features = nn.Sequential(*features)
+
+        self.norm = nn.BatchNorm2d(output_channel, eps=NORM_EPS)
+
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.proj_head = nn.Sequential(
+            nn.Linear(output_channel, num_classes),
+        )
+
+        self.stage_out_idx = [sum(depths[:idx + 1]) - 1 for idx in range(len(depths))]
+        print('initialize_weights...')
+        self._initialize_weights()
+
+    def merge_bn(self):
+        self.eval()
+        for idx, module in self.named_modules():
+            if isinstance(module, NCB) or isinstance(module, NTB):
+                module.merge_bn()
+
+    def _initialize_weights(self):
+        for n, m in self.named_modules():
+            if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm, nn.BatchNorm1d)):
+                nn.init.constant_(m.weight, 1.0)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                trunc_normal_(m.weight, std=.02)
+                if hasattr(m, 'bias') and m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Conv2d):
+                trunc_normal_(m.weight, std=.02)
+                if hasattr(m, 'bias') and m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        x = self.stem(x)
+        for idx, layer in enumerate(self.features):
+            if self.use_checkpoint:
+                x = checkpoint.checkpoint(layer, x)
+            else:
+                x = layer(x)
+        x = self.norm(x)
+        x = self.avgpool(x)
+        x = torch.flatten(x, 1)
+        x = self.proj_head(x)
+        return x
+
+
+@register_model
+def nextvit_small(pretrained=False, pretrained_cfg=None, **kwargs):
+    model = NextViT(stem_chs=[64, 32, 64], depths=[3, 4, 10, 3], path_dropout=0.1, **kwargs)
+    return model
+
+
+@register_model
+def nextvit_base(pretrained=False, pretrained_cfg=None, **kwargs):
+    model = NextViT(stem_chs=[64, 32, 64], depths=[3, 4, 20, 3], path_dropout=0.2, **kwargs)
+    return model
+
+
+@register_model
+def nextvit_large(pretrained=False, pretrained_cfg=None, **kwargs):
+    model = NextViT(stem_chs=[64, 32, 64], depths=[3, 4, 30, 3], path_dropout=0.2, **kwargs)
+    return model
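
Usage sketch (not part of the patch): once this diff is applied, the variants registered via @register_model above should be reachable through timm's normal model factory. This is a minimal smoke test, assuming an install of this branch and that a `utils` module providing `merge_pre_bn` (as in the upstream Next-ViT repository) is importable, since nextvit.py imports it at module level; pretrained weights are not wired up here, so only random initialization is exercised.

    import torch
    import timm

    # 'nextvit_small' is exposed by the @register_model decorators in nextvit.py.
    model = timm.create_model('nextvit_small', num_classes=1000)
    model.eval()

    with torch.no_grad():
        out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # expected: torch.Size([1, 1000])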