diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 79778ab1..17faba53 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -170,6 +170,11 @@ default_cfgs = {
             '/vit_base_patch16_224_1k_miil_84_4.pth',
         mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear',
     ),
+
+    # experimental
+    'vit_small_patch16_36x1_224': _cfg(url=''),
+    'vit_small_patch16_18x2_224': _cfg(url=''),
+    'vit_base_patch16_18x2_224': _cfg(url=''),
 }
 
 
@@ -201,28 +206,81 @@ class Attention(nn.Module):
         return x
 
 
+class LayerScale(nn.Module):
+    def __init__(self, dim, init_values=1e-5, inplace=False):
+        super().__init__()
+        self.inplace = inplace
+        self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x):
+        return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
 class Block(nn.Module):
 
     def __init__(
-            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
+            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
             drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
         super().__init__()
         self.norm1 = norm_layer(dim)
         self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
     def forward(self, x):
-        x = x + self.drop_path1(self.attn(self.norm1(x)))
-        x = x + self.drop_path2(self.mlp(self.norm2(x)))
+        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
+        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
         return x
 
 
+class ParallelBlock(nn.Module):
+
+    def __init__(
+            self, dim, num_heads, num_parallel=2, mlp_ratio=4., qkv_bias=False, init_values=None,
+            drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.num_parallel = num_parallel
+        self.attns = nn.ModuleList()
+        self.ffns = nn.ModuleList()
+        for _ in range(num_parallel):
+            self.attns.append(nn.Sequential(OrderedDict([
+                ('norm', norm_layer(dim)),
+                ('attn', Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)),
+                ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
+                ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
+            ])))
+            self.ffns.append(nn.Sequential(OrderedDict([
+                ('norm', norm_layer(dim)),
+                ('mlp', Mlp(dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)),
+                ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
+                ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
+            ])))
+
+    def _forward_jit(self, x):
+        x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0)
+        x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0)
+        return x
+
+    @torch.jit.ignore
+    def _forward(self, x):
+        x = x + sum(attn(x) for attn in self.attns)
+        x = x + sum(ffn(x) for ffn in self.ffns)
+        return x
+
+    def forward(self, x):
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
+            return self._forward_jit(x)
+        else:
+            return self._forward(x)
+
+
 class VisionTransformer(nn.Module):
     """ Vision Transformer
 
@@ -233,8 +291,8 @@ class VisionTransformer(nn.Module):
     def __init__(
             self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
             embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None,
-            drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='',
-            embed_layer=PatchEmbed, norm_layer=None, act_layer=None):
+            drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', init_values=None,
+            embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
         """
         Args:
             img_size (int, tuple): input image size
@@ -248,10 +306,11 @@ class VisionTransformer(nn.Module):
             mlp_ratio (int): ratio of mlp hidden dim to embedding dim
             qkv_bias (bool): enable bias for qkv if True
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
-            weight_init: (str): weight init scheme
             drop_rate (float): dropout rate
             attn_drop_rate (float): attention dropout rate
             drop_path_rate (float): stochastic depth rate
+            weight_init: (str): weight init scheme
+            init_values: (float): layer-scale init values
             embed_layer (nn.Module): patch embedding layer
             norm_layer: (nn.Module): normalization layer
             act_layer: (nn.Module): MLP activation layer
@@ -277,9 +336,9 @@ class VisionTransformer(nn.Module):
 
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
         self.blocks = nn.Sequential(*[
-            Block(
-                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
-                attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
+            block_fn(
+                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values,
+                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
             for i in range(depth)])
         use_fc_norm = self.global_pool == 'avg'
         self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
@@ -941,3 +1000,37 @@ def vit_base_patch16_224_miil(pretrained=False, **kwargs):
     model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
     model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs)
     return model
+
+
+@register_model
+def vit_small_patch16_36x1_224(pretrained=False, **kwargs):
+    """ ViT-Small w/ LayerScale + 36 x 1 (36 block serial) config. Experimental, may remove.
+    Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
+    Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
+ """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=36, num_heads=6, init_values=1e-5, **kwargs) + model = _create_vision_transformer('vit_small_patch16_36x1_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch16_18x2_224(pretrained=False, **kwargs): + """ ViT-Small w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove. + Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 + Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow. + """ + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelBlock, **kwargs) + model = _create_vision_transformer('vit_small_patch16_18x2_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_18x2_224(pretrained=False, **kwargs): + """ ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove. + Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock, **kwargs) + model = _create_vision_transformer('vit_base_patch16_18x2_224', pretrained=pretrained, **model_kwargs) + return model