Overhaul FocalNet implementation

pull/1628/head
Ross Wightman 2 years ago
parent 7266c5c716
commit 848d200767

@ -1,19 +1,32 @@
""" FocalNet
As described in `Focal Modulation Networks` - https://arxiv.org/abs/2203.11926
Significant modifications and refactoring from the original impl at https://github.com/microsoft/FocalNet
This impl is/has:
* fully convolutional, NCHW tensor layout throughout, seemed to have minimal performance impact but more flexible
* re-ordered downsample / layer so that striding always at beginning of layer (stage)
* no input size constraints or input resolution/H/W tracking through the model
* torchscript fixed and a number of quirks cleaned up
* feature extraction support via `features_only=True`
"""
# -------------------------------------------------------- # --------------------------------------------------------
# FocalNets -- Focal Modulation Networks # FocalNets -- Focal Modulation Networks
# Copyright (c) 2022 Microsoft # Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
# Written by Jianwei Yang (jianwyan@microsoft.com) # Written by Jianwei Yang (jianwyan@microsoft.com)
# -------------------------------------------------------- # --------------------------------------------------------
from functools import partial
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert from timm.layers import Mlp, DropPath, LayerNorm2d, trunc_normal_, ClassifierHead, NormMlpClassifierHead
from ._builder import build_model_with_cfg from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function from ._manipulate import named_apply
from ._registry import register_model from ._registry import register_model
__all__ = ['FocalNet'] __all__ = ['FocalNet']
@ -27,9 +40,10 @@ class FocalModulation(nn.Module):
focal_level, focal_level,
focal_factor=2, focal_factor=2,
bias=True, bias=True,
proj_drop=0., use_post_norm=False,
use_postln_in_modulation=False,
normalize_modulator=False, normalize_modulator=False,
proj_drop=0.,
norm_layer=LayerNorm2d,
): ):
super().__init__() super().__init__()
@ -37,69 +51,70 @@ class FocalModulation(nn.Module):
self.focal_window = focal_window self.focal_window = focal_window
self.focal_level = focal_level self.focal_level = focal_level
self.focal_factor = focal_factor self.focal_factor = focal_factor
self.use_postln_in_modulation = use_postln_in_modulation self.use_post_norm = use_post_norm
self.normalize_modulator = normalize_modulator self.normalize_modulator = normalize_modulator
self.input_split = [dim, dim, self.focal_level + 1]
self.f = nn.Linear(dim, 2 * dim + (self.focal_level + 1), bias=bias) self.f = nn.Conv2d(dim, 2 * dim + (self.focal_level + 1), kernel_size=1, bias=bias)
self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, bias=bias) self.h = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
self.act = nn.GELU() self.act = nn.GELU()
self.proj = nn.Linear(dim, dim) self.proj = nn.Conv2d(dim, dim, kernel_size=1)
self.proj_drop = nn.Dropout(proj_drop) self.proj_drop = nn.Dropout(proj_drop)
self.focal_layers = nn.ModuleList() self.focal_layers = nn.ModuleList()
self.kernel_sizes = [] self.kernel_sizes = []
for k in range(self.focal_level): for k in range(self.focal_level):
kernel_size = self.focal_factor * k + self.focal_window kernel_size = self.focal_factor * k + self.focal_window
self.focal_layers.append( self.focal_layers.append(nn.Sequential(
nn.Sequential( nn.Conv2d(dim, dim, kernel_size=kernel_size, groups=dim, padding=kernel_size // 2, bias=False),
nn.Conv2d(
dim, dim, kernel_size=kernel_size, stride=1,
groups=dim, padding=kernel_size // 2, bias=False),
nn.GELU(), nn.GELU(),
) ))
)
self.kernel_sizes.append(kernel_size) self.kernel_sizes.append(kernel_size)
if self.use_postln_in_modulation: self.norm = norm_layer(dim) if self.use_post_norm else nn.Identity()
self.ln = nn.LayerNorm(dim)
def forward(self, x): def forward(self, x):
""" """
Args: Args:
x: input features with shape of (B, H, W, C) x: input features with shape of (B, H, W, C)
""" """
C = x.shape[-1] C = x.shape[1]
# pre linear projection # pre linear projection
x = self.f(x).permute(0, 3, 1, 2).contiguous() x = self.f(x)
q, ctx, self.gates = torch.split(x, (C, C, self.focal_level + 1), 1) q, ctx, gates = torch.split(x, self.input_split, 1)
# context aggreation # context aggreation
ctx_all = 0 ctx_all = 0
for l in range(self.focal_level): for l, focal_layer in enumerate(self.focal_layers):
ctx = self.focal_layers[l](ctx) ctx = focal_layer(ctx)
ctx_all = ctx_all + ctx * self.gates[:, l:l + 1] ctx_all = ctx_all + ctx * gates[:, l:l + 1]
ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True)) ctx_global = self.act(ctx.mean((2, 3), keepdim=True))
ctx_all = ctx_all + ctx_global * self.gates[:, self.focal_level:] ctx_all = ctx_all + ctx_global * gates[:, self.focal_level:]
# normalize context # normalize context
if self.normalize_modulator: if self.normalize_modulator:
ctx_all = ctx_all / (self.focal_level + 1) ctx_all = ctx_all / (self.focal_level + 1)
# focal modulation # focal modulation
self.modulator = self.h(ctx_all) x_out = q * self.h(ctx_all)
x_out = q * self.modulator x_out = self.norm(x_out)
x_out = x_out.permute(0, 2, 3, 1).contiguous()
if self.use_postln_in_modulation:
x_out = self.ln(x_out)
# post linear porjection # post linear projection
x_out = self.proj(x_out) x_out = self.proj(x_out)
x_out = self.proj_drop(x_out) x_out = self.proj_drop(x_out)
return x_out return x_out
def extra_repr(self) -> str:
return f'dim={self.dim}' class LayerScale2d(nn.Module):
def __init__(self, dim, init_values=1e-5, inplace=False):
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x):
gamma = self.gamma.view(1, -1, 1, 1)
return x.mul_(gamma) if self.inplace else x * gamma
class FocalNetBlock(nn.Module): class FocalNetBlock(nn.Module):
@ -107,297 +122,238 @@ class FocalNetBlock(nn.Module):
Args: Args:
dim (int): Number of input channels. dim (int): Number of input channels.
input_resolution (tuple[int]): Input resulotion.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
drop (float, optional): Dropout rate. Default: 0.0 proj_drop (float, optional): Dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0 drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
focal_level (int): Number of focal levels. focal_level (int): Number of focal levels.
focal_window (int): Focal window size at first focal level focal_window (int): Focal window size at first focal level
layerscale_value (float): Initial layerscale value layerscale_value (float): Initial layerscale value
use_postln (bool): Whether to use layernorm after modulation use_post_norm (bool): Whether to use layernorm after modulation
""" """
def __init__( def __init__(
self, self,
dim, dim,
input_resolution,
mlp_ratio=4., mlp_ratio=4.,
drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
focal_level=1, focal_level=1,
focal_window=3, focal_window=3,
layerscale_value=1e-4, use_post_norm=False,
use_postln=False, use_post_norm_in_modulation=False,
use_postln_in_modulation=False,
normalize_modulator=False, normalize_modulator=False,
layerscale_value=1e-4,
proj_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=LayerNorm2d,
): ):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
self.input_resolution = input_resolution
self.mlp_ratio = mlp_ratio self.mlp_ratio = mlp_ratio
self.focal_window = focal_window self.focal_window = focal_window
self.focal_level = focal_level self.focal_level = focal_level
self.use_postln = use_postln self.use_post_norm = use_post_norm
self.norm1 = norm_layer(dim) self.norm1 = norm_layer(dim) if not use_post_norm else nn.Identity()
self.modulation = FocalModulation( self.modulation = FocalModulation(
dim, dim,
proj_drop=drop,
focal_window=focal_window, focal_window=focal_window,
focal_level=self.focal_level, focal_level=self.focal_level,
use_postln_in_modulation=use_postln_in_modulation, use_post_norm=use_post_norm_in_modulation,
normalize_modulator=normalize_modulator, normalize_modulator=normalize_modulator,
proj_drop=proj_drop,
norm_layer=norm_layer,
) )
self.norm1_post = norm_layer(dim) if use_post_norm else nn.Identity()
self.ls1 = LayerScale2d(dim, layerscale_value) if layerscale_value is not None else nn.Identity()
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) if not use_post_norm else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp( self.mlp = Mlp(
in_features=dim, in_features=dim,
hidden_features=mlp_hidden_dim, hidden_features=int(dim * mlp_ratio),
act_layer=act_layer, act_layer=act_layer,
drop=drop, drop=proj_drop,
use_conv=True,
) )
self.norm2_post = norm_layer(dim) if use_post_norm else nn.Identity()
self.gamma_1 = 1.0 self.ls2 = LayerScale2d(dim, layerscale_value) if layerscale_value is not None else nn.Identity()
self.gamma_2 = 1.0 self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
if layerscale_value is not None:
self.gamma_1 = nn.Parameter(layerscale_value * torch.ones(dim))
self.gamma_2 = nn.Parameter(layerscale_value * torch.ones(dim))
self.H = None
self.W = None
def forward(self, x): def forward(self, x):
H, W = self.H, self.W
B, L, C = x.shape
shortcut = x shortcut = x
# Focal Modulation # Focal Modulation
x = x if self.use_postln else self.norm1(x) x = self.norm1(x)
x = x.view(B, H, W, C) x = self.modulation(x)
x = self.modulation(x).view(B, H * W, C) x = self.norm1_post(x)
x = x if not self.use_postln else self.norm1(x) x = shortcut + self.drop_path1(self.ls1(x))
# FFN # FFN
x = shortcut + self.drop_path(self.gamma_1 * x) x = x + self.drop_path2(self.ls2(self.norm2_post(self.mlp(self.norm2(x)))))
x = x + self.drop_path(self.gamma_2 * (self.norm2(self.mlp(x)) if self.use_postln else self.mlp(self.norm2(x))))
return x return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, " \
f"mlp_ratio={self.mlp_ratio}"
class BasicLayer(nn.Module): class BasicLayer(nn.Module):
""" A basic Focal Transformer layer for one stage. """ A basic Focal Transformer layer for one stage.
Args: Args:
dim (int): Number of input channels. dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks. depth (int): Number of blocks.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
drop (float, optional): Dropout rate. Default: 0.0 drop (float, optional): Dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None downsample (bool): Downsample layer at start of the layer. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
focal_level (int): Number of focal levels focal_level (int): Number of focal levels
focal_window (int): Focal window size at first focal level focal_window (int): Focal window size at first focal level
layerscale_value (float): Initial layerscale value layerscale_value (float): Initial layerscale value
use_postln (bool): Whether to use layer norm after modulation use_post_norm (bool): Whether to use layer norm after modulation
""" """
def __init__( def __init__(
self, self,
dim, dim,
out_dim, out_dim,
input_resolution,
depth, depth,
mlp_ratio=4., mlp_ratio=4.,
drop=0., downsample=True,
drop_path=0.,
norm_layer=nn.LayerNorm,
downsample=None,
use_checkpoint=False,
focal_level=1, focal_level=1,
focal_window=1, focal_window=1,
use_conv_embed=False, use_overlap_down=False,
use_post_norm=False,
use_post_norm_in_modulation=False,
normalize_modulator=False,
layerscale_value=1e-4, layerscale_value=1e-4,
use_postln=False, proj_drop=0.,
use_postln_in_modulation=False, drop_path=0.,
normalize_modulator=False norm_layer=LayerNorm2d,
): ):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
self.input_resolution = input_resolution
self.depth = depth self.depth = depth
self.use_checkpoint = use_checkpoint self.grad_checkpointing = False
if downsample:
self.downsample = Downsample(
in_chs=dim,
out_chs=out_dim,
stride=2,
overlap=use_overlap_down,
norm_layer=norm_layer,
)
else:
self.downsample = nn.Identity()
# build blocks # build blocks
self.blocks = nn.ModuleList([ self.blocks = nn.ModuleList([
FocalNetBlock( FocalNetBlock(
dim=dim, dim=out_dim,
input_resolution=input_resolution,
mlp_ratio=mlp_ratio, mlp_ratio=mlp_ratio,
drop=drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer,
focal_level=focal_level, focal_level=focal_level,
focal_window=focal_window, focal_window=focal_window,
layerscale_value=layerscale_value, use_post_norm=use_post_norm,
use_postln=use_postln, use_post_norm_in_modulation=use_post_norm_in_modulation,
use_postln_in_modulation=use_postln_in_modulation,
normalize_modulator=normalize_modulator, normalize_modulator=normalize_modulator,
) layerscale_value=layerscale_value,
for i in range(depth)]) proj_drop=proj_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
if downsample is not None:
self.downsample = downsample(
img_size=input_resolution,
patch_size=2,
in_chans=dim,
embed_dim=out_dim,
use_conv_embed=use_conv_embed,
norm_layer=norm_layer, norm_layer=norm_layer,
is_stem=False
) )
else: for i in range(depth)])
self.downsample = None
def forward(self, x, H, W): def forward(self, x):
x = self.downsample(x)
for blk in self.blocks: for blk in self.blocks:
blk.H, blk.W = H, W if self.grad_checkpointing and not torch.jit.is_scripting():
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x) x = checkpoint.checkpoint(blk, x)
else: else:
x = blk(x) x = blk(x)
return x
if self.downsample is not None:
x = x.transpose(1, 2).reshape(x.shape[0], -1, H, W)
x, Ho, Wo = self.downsample(x)
else:
Ho, Wo = H, W
return x, Ho, Wo
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
class PatchEmbed(nn.Module):
r""" Image to Patch Embedding
class Downsample(nn.Module):
r"""
Args: Args:
img_size (int): Image size. Default: 224. in_chs (int): Number of input image channels
patch_size (int): Patch token size. Default: 4. out_chs (int): Number of linear projection output channels
in_chans (int): Number of input image channels. Default: 3. stride (int): Downsample stride. Default: 4.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None norm_layer (nn.Module, optional): Normalization layer. Default: None
""" """
def __init__( def __init__(
self, self,
img_size=(224, 224), in_chs,
patch_size=4, out_chs,
in_chans=3, stride=4,
embed_dim=96, overlap=False,
use_conv_embed=False,
norm_layer=None, norm_layer=None,
is_stem=False,
): ):
super().__init__() super().__init__()
patch_size = to_2tuple(patch_size) self.stride = stride
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
padding = 0 padding = 0
kernel_size = patch_size kernel_size = stride
stride = patch_size if overlap:
if use_conv_embed: assert stride in (2, 4)
# if we choose to use conv embedding, then we treat the stem and non-stem differently if stride == 4:
if is_stem: kernel_size, padding = 7, 2
kernel_size = 7 elif stride == 2:
padding = 2 kernel_size, padding = 3, 1
stride = 4 self.proj = nn.Conv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride, padding=padding)
else: self.norm = norm_layer(out_chs) if norm_layer is not None else nn.Identity()
kernel_size = 3
padding = 1
stride = 2
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x): def forward(self, x):
B, C, H, W = x.shape
x = self.proj(x) x = self.proj(x)
H, W = x.shape[2:]
x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x) x = self.norm(x)
return x, H, W return x
class FocalNet(nn.Module): class FocalNet(nn.Module):
r""" Focal Modulation Networks (FocalNets) r""" Focal Modulation Networks (FocalNets)
Args: Args:
img_size (int | tuple(int)): Input image size. Default 224
patch_size (int | tuple(int)): Patch size. Default: 4
in_chans (int): Number of input image channels. Default: 3 in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000 num_classes (int): Number of classes for classification head. Default: 1000
embed_dim (int): Patch embedding dimension. Default: 96 embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Focal Transformer layer. depths (tuple(int)): Depth of each Focal Transformer layer.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
drop_rate (float): Dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
patch_norm (bool): If True, add normalization after patch embedding. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
focal_levels (list): How many focal levels at all stages. Note that this excludes the finest-grain level. focal_levels (list): How many focal levels at all stages. Note that this excludes the finest-grain level.
Default: [1, 1, 1, 1] Default: [1, 1, 1, 1]
focal_windows (list): The focal window size at all stages. Default: [7, 5, 3, 1] focal_windows (list): The focal window size at all stages. Default: [7, 5, 3, 1]
use_conv_embed (bool): Whether to use convolutional embedding. use_overlap_down (bool): Whether to use convolutional embedding.
use_post_norm (bool): Whether to use layernorm after modulation (it helps stablize training of large models)
layerscale_value (float): Value for layer scale. Default: 1e-4 layerscale_value (float): Value for layer scale. Default: 1e-4
use_postln (bool): Whether to use layernorm after modulation (it helps stablize training of large models) drop_rate (float): Dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
""" """
def __init__( def __init__(
self, self,
img_size=224,
patch_size=4,
in_chans=3, in_chans=3,
num_classes=1000, num_classes=1000,
global_pool='avg',
embed_dim=96, embed_dim=96,
depths=[2, 2, 6, 2], depths=(2, 2, 6, 2),
mlp_ratio=4., mlp_ratio=4.,
focal_levels=(2, 2, 2, 2),
focal_windows=(3, 3, 3, 3),
use_overlap_down=False,
use_post_norm=False,
use_post_norm_in_modulation=False,
normalize_modulator=False,
head_hidden_size=None,
head_init_scale=1.0,
layerscale_value=None,
drop_rate=0., drop_rate=0.,
proj_drop_rate=0.,
drop_path_rate=0.1, drop_path_rate=0.1,
norm_layer=nn.LayerNorm, norm_layer=partial(LayerNorm2d, eps=1e-5),
patch_norm=True,
use_checkpoint=False,
focal_levels=[2, 2, 2, 2],
focal_windows=[3, 3, 3, 3],
use_conv_embed=False,
layerscale_value=None,
use_postln=False,
use_postln_in_modulation=False,
normalize_modulator=False,
**kwargs, **kwargs,
): ):
super().__init__() super().__init__()
@ -407,129 +363,186 @@ class FocalNet(nn.Module):
self.num_classes = num_classes self.num_classes = num_classes
self.embed_dim = embed_dim self.embed_dim = embed_dim
self.patch_norm = patch_norm
self.num_features = embed_dim[-1] self.num_features = embed_dim[-1]
self.mlp_ratio = mlp_ratio self.feature_info = []
# split image into patches using either non-overlapped embedding or overlapped embedding self.stem = Downsample(
self.patch_embed = PatchEmbed( in_chs=in_chans,
img_size=to_2tuple(img_size), out_chs=embed_dim[0],
patch_size=patch_size, overlap=use_overlap_down,
in_chans=in_chans, norm_layer=norm_layer,
embed_dim=embed_dim[0],
use_conv_embed=use_conv_embed,
norm_layer=norm_layer if self.patch_norm else None,
is_stem=True
) )
in_dim = embed_dim[0]
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
layers = []
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers): for i_layer in range(self.num_layers):
out_dim = embed_dim[i_layer]
layer = BasicLayer( layer = BasicLayer(
dim=embed_dim[i_layer], dim=in_dim,
out_dim=embed_dim[i_layer + 1] if (i_layer < self.num_layers - 1) else None, out_dim=out_dim,
input_resolution=(
patches_resolution[0] // (2 ** i_layer), patches_resolution[1] // (2 ** i_layer)),
depth=depths[i_layer], depth=depths[i_layer],
mlp_ratio=self.mlp_ratio, mlp_ratio=mlp_ratio,
drop=drop_rate, downsample=i_layer > 0,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchEmbed if (i_layer < self.num_layers - 1) else None,
focal_level=focal_levels[i_layer], focal_level=focal_levels[i_layer],
focal_window=focal_windows[i_layer], focal_window=focal_windows[i_layer],
use_conv_embed=use_conv_embed, use_overlap_down=use_overlap_down,
use_checkpoint=use_checkpoint, use_post_norm=use_post_norm,
use_post_norm_in_modulation=use_post_norm_in_modulation,
normalize_modulator=normalize_modulator,
layerscale_value=layerscale_value, layerscale_value=layerscale_value,
use_postln=use_postln, proj_drop=proj_drop_rate,
use_postln_in_modulation=use_postln_in_modulation, drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
normalize_modulator=normalize_modulator norm_layer=norm_layer,
) )
self.layers.append(layer) in_dim = out_dim
layers += [layer]
self.feature_info += [dict(num_chs=out_dim, reduction=4 * 2 ** i_layer, module=f'layers.{i_layer}')]
self.layers = nn.Sequential(*layers)
if head_hidden_size:
self.norm = nn.Identity()
self.head = NormMlpClassifierHead(
self.num_features,
num_classes,
hidden_size=head_hidden_size,
pool_type=global_pool,
drop_rate=drop_rate,
norm_layer=norm_layer,
)
else:
self.norm = norm_layer(self.num_features) self.norm = norm_layer(self.num_features)
self.avgpool = nn.AdaptiveAvgPool1d(1) self.head = ClassifierHead(
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() self.num_features,
num_classes,
self.apply(self._init_weights) pool_type=global_pool,
drop_rate=drop_rate
)
def _init_weights(self, m): named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore @torch.jit.ignore
def no_weight_decay(self): def no_weight_decay(self):
return {''} return {''}
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
for l in self.layers:
l.set_grad_checkpointing(enable=enable)
@torch.jit.ignore
def get_classifier(self):
return self.classifier.fc
def reset_classifier(self, num_classes, global_pool=None):
self.classifier.reset(num_classes, global_pool=global_pool)
def forward_features(self, x): def forward_features(self, x):
x, H, W = self.patch_embed(x) x = self.stem(x)
x = self.pos_drop(x) x = self.layers(x)
x = self.norm(x)
for layer in self.layers:
x, H, W = layer(x, H, W)
x = self.norm(x) # B L C
x = self.avgpool(x.transpose(1, 2)) # B C 1
x = torch.flatten(x, 1)
return x return x
def forward_head(self, x, pre_logits: bool = False):
return self.head(x, pre_logits=pre_logits)
def forward(self, x): def forward(self, x):
x = self.forward_features(x) x = self.forward_features(x)
x = self.head(x) x = self.forward_head(x)
return x return x
def _init_weights(module, name=None, head_init_scale=1.0):
if isinstance(module, nn.Conv2d):
trunc_normal_(module.weight, std=.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Linear):
trunc_normal_(module.weight, std=.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
if name and 'head.fc' in name:
module.weight.data.mul_(head_init_scale)
module.bias.data.mul_(head_init_scale)
def _cfg(url='', **kwargs): def _cfg(url='', **kwargs):
return { return {
'url': url, 'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 'crop_pct': .9, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head', 'first_conv': 'stem.proj', 'classifier': 'head.fc',
**kwargs **kwargs
} }
default_cfgs = { default_cfgs = {
"focalnet_tiny_srf": _cfg(), "focalnet_tiny_srf": _cfg(
"focalnet_small_srf": _cfg(url="https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_small_srf.pth"), url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_tiny_srf.pth'),
"focalnet_base_srf": _cfg(), "focalnet_small_srf": _cfg(
"focalnet_tiny_lrf": _cfg(), url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_small_srf.pth'),
"focalnet_small_lrf": _cfg(), "focalnet_base_srf": _cfg(
"focalnet_base_lrf": _cfg(url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_base_lrf.pth'), url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_base_srf.pth'),
"focalnet_large_fl3": _cfg(url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_large_lrf_384.pth', input_size=(3, 384, 384), num_classes=21842), "focalnet_tiny_lrf": _cfg(
"focalnet_large_fl4": _cfg(url="https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_large_lrf_384_fl4.pth", input_size=(3, 384, 384), num_classes=21842), url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_tiny_lrf.pth'),
"focalnet_small_lrf": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_small_lrf.pth'),
"focalnet_base_lrf": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_base_lrf.pth'),
"focalnet_large_fl3": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_large_lrf_384.pth',
input_size=(3, 384, 384), crop_pct=1.0, num_classes=21842),
"focalnet_large_fl4": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_large_lrf_384_fl4.pth',
input_size=(3, 384, 384), crop_pct=1.0, num_classes=21842),
"focalnet_xlarge_fl3": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_xlarge_lrf_384.pth',
input_size=(3, 384, 384), crop_pct=1.0, num_classes=21842),
"focalnet_xlarge_fl4": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_xlarge_lrf_384_fl4.pth',
input_size=(3, 384, 384), crop_pct=1.0, num_classes=21842),
"focalnet_huge_fl3": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_huge_lrf_224.pth',
num_classes=0),
"focalnet_huge_fl4": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_huge_lrf_224_fl4.pth',
num_classes=0),
} }
def checkpoint_filter_fn(state_dict, model): def checkpoint_filter_fn(state_dict, model: FocalNet):
if 'stem.proj.weight' in state_dict:
return
import re
out_dict = {} out_dict = {}
if 'model' in state_dict: if 'model' in state_dict:
# For deit models
state_dict = state_dict['model'] state_dict = state_dict['model']
dest_dict = model.state_dict()
for k, v in state_dict.items(): for k, v in state_dict.items():
if any([n in k for n in ('relative_position_index', 'relative_coords_table')]): k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k)
continue # skip buffers that should not be persistent k = k.replace('patch_embed', 'stem')
k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k)
if 'norm' in k and k not in dest_dict:
k = re.sub(r'norm([0-9])', r'norm\1_post', k)
k = k.replace('ln.', 'norm.')
k = k.replace('head', 'head.fc')
if dest_dict[k].shape != v.shape:
v = v.reshape(dest_dict[k].shape)
out_dict[k] = v out_dict[k] = v
return out_dict return out_dict
def _create_focalnet(variant, pretrained=False, **kwargs): def _create_focalnet(variant, pretrained=False, **kwargs):
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
out_indices = kwargs.pop('out_indices', default_out_indices)
model = build_model_with_cfg( model = build_model_with_cfg(
FocalNet, variant, pretrained, FocalNet, variant, pretrained,
pretrained_filter_fn=checkpoint_filter_fn, pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
**kwargs) **kwargs)
return model return model
@ -569,10 +582,13 @@ def focalnet_base_lrf(pretrained=False, **kwargs):
model_kwargs = dict(depths=[2, 2, 18, 2], embed_dim=128, focal_levels=[3, 3, 3, 3], **kwargs) model_kwargs = dict(depths=[2, 2, 18, 2], embed_dim=128, focal_levels=[3, 3, 3, 3], **kwargs)
return _create_focalnet('focalnet_base_lrf', pretrained=pretrained, **model_kwargs) return _create_focalnet('focalnet_base_lrf', pretrained=pretrained, **model_kwargs)
# FocalNet large+ models # FocalNet large+ models
@register_model @register_model
def focalnet_large_fl3(pretrained=False, **kwargs): def focalnet_large_fl3(pretrained=False, **kwargs):
model_kwargs = dict(depths=[2, 2, 18, 2], embed_dim=192, focal_levels=[3, 3, 3, 3], **kwargs) model_kwargs = dict(
depths=[2, 2, 18, 2], embed_dim=192, focal_levels=[3, 3, 3, 3], focal_windows=[5] * 4,
use_post_norm=True, use_overlap_down=True, layerscale_value=1e-4, **kwargs)
return _create_focalnet('focalnet_large_fl3', pretrained=pretrained, **model_kwargs) return _create_focalnet('focalnet_large_fl3', pretrained=pretrained, **model_kwargs)
@ -580,37 +596,38 @@ def focalnet_large_fl3(pretrained=False, **kwargs):
def focalnet_large_fl4(pretrained=False, **kwargs): def focalnet_large_fl4(pretrained=False, **kwargs):
model_kwargs = dict( model_kwargs = dict(
depths=[2, 2, 18, 2], embed_dim=192, focal_levels=[4, 4, 4, 4], depths=[2, 2, 18, 2], embed_dim=192, focal_levels=[4, 4, 4, 4],
use_conv_embed=True, layerscale_value=1e-4, **kwargs) use_post_norm=True, use_overlap_down=True, layerscale_value=1e-4, **kwargs)
return _create_focalnet('focalnet_large_fl4', pretrained=pretrained, **model_kwargs) return _create_focalnet('focalnet_large_fl4', pretrained=pretrained, **model_kwargs)
#
# @register_model
# def focalnet_large_fl4(pretrained=False, **kwargs):
# model_kwargs = dict(depths=[2, 2, 18, 2], embed_dim=192, focal_levels=[4, 4, 4, 4], **kwargs)
# return _create_focalnet('focalnet_large_fl4', pretrained=pretrained, **model_kwargs)
@register_model @register_model
def focalnet_xlarge_fl3(pretrained=False, **kwargs): def focalnet_xlarge_fl3(pretrained=False, **kwargs):
model_kwargs = dict(depths=[2, 2, 18, 2], embed_dim=256, focal_levels=[3, 3, 3, 3], **kwargs) model_kwargs = dict(
depths=[2, 2, 18, 2], embed_dim=256, focal_levels=[3, 3, 3, 3], focal_windows=[5] * 4,
use_post_norm=True, use_overlap_down=True, layerscale_value=1e-4, **kwargs)
return _create_focalnet('focalnet_xlarge_fl3', pretrained=pretrained, **model_kwargs) return _create_focalnet('focalnet_xlarge_fl3', pretrained=pretrained, **model_kwargs)
@register_model @register_model
def focalnet_xlarge_fl4(pretrained=False, **kwargs): def focalnet_xlarge_fl4(pretrained=False, **kwargs):
model_kwargs = dict(depths=[2, 2, 18, 2], embed_dim=256, focal_levels=[4, 4, 4, 4], **kwargs) model_kwargs = dict(
depths=[2, 2, 18, 2], embed_dim=256, focal_levels=[4, 4, 4, 4],
use_post_norm=True, use_overlap_down=True, layerscale_value=1e-4, **kwargs)
return _create_focalnet('focalnet_xlarge_fl4', pretrained=pretrained, **model_kwargs) return _create_focalnet('focalnet_xlarge_fl4', pretrained=pretrained, **model_kwargs)
@register_model @register_model
def focalnet_huge_fl3(pretrained=False, **kwargs): def focalnet_huge_fl3(pretrained=False, **kwargs):
model_kwargs = dict(depths=[2, 2, 18, 2], embed_dim=352, focal_levels=[3, 3, 3, 3], **kwargs) model_kwargs = dict(
depths=[2, 2, 18, 2], embed_dim=352, focal_levels=[3, 3, 3, 3], focal_windows=[3] * 4,
use_post_norm=True, use_post_norm_in_modulation=True, use_overlap_down=True, layerscale_value=1e-4, **kwargs)
return _create_focalnet('focalnet_huge_fl3', pretrained=pretrained, **model_kwargs) return _create_focalnet('focalnet_huge_fl3', pretrained=pretrained, **model_kwargs)
@register_model @register_model
def focalnet_huge_fl4(pretrained=False, **kwargs): def focalnet_huge_fl4(pretrained=False, **kwargs):
model_kwargs = dict(depths=[2, 2, 18, 2], embed_dim=352, focal_levels=[4, 4, 4, 4], **kwargs) model_kwargs = dict(
depths=[2, 2, 18, 2], embed_dim=352, focal_levels=[4, 4, 4, 4],
use_post_norm=True, use_post_norm_in_modulation=True, use_overlap_down=True, layerscale_value=1e-4, **kwargs)
return _create_focalnet('focalnet_huge_fl4', pretrained=pretrained, **model_kwargs) return _create_focalnet('focalnet_huge_fl4', pretrained=pretrained, **model_kwargs)

Loading…
Cancel
Save