commit a16ea1e355
@@ -0,0 +1,207 @@
import math
from typing import List, Tuple, Optional, Union

import torch
from torch import nn as nn


def pixel_freq_bands(
        num_bands: int,
        max_freq: float = 224.,
        linear_bands: bool = True,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None,
):
    if linear_bands:
        bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=dtype, device=device)
    else:
        bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=dtype, device=device)
    return bands * torch.pi


def inv_freq_bands(
        num_bands: int,
        temperature: float = 100000.,
        step: int = 2,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None,
) -> torch.Tensor:
    inv_freq = 1. / (temperature ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands))
    return inv_freq


def build_sincos2d_pos_embed(
        feat_shape: List[int],
        dim: int = 64,
        temperature: float = 10000.,
        reverse_coord: bool = False,
        interleave_sin_cos: bool = False,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None
) -> torch.Tensor:
    """ Build a 2D sin-cos position embedding.

    Args:
        feat_shape: spatial shape (H, W) of the feature map to encode
        dim: embedding dimension, must be divisible by 4
        temperature: frequency temperature for the inverse-frequency bands
        reverse_coord: stack grid order W, H instead of H, W
        interleave_sin_cos: sin, cos, sin, cos stack instead of sin, sin, cos, cos
        dtype: output dtype
        device: output device

    Returns:
        Position embedding tensor of shape (H * W, dim)
    """
    assert dim % 4 == 0, 'Embed dimension must be divisible by 4 for sin-cos 2D position embedding'
    pos_dim = dim // 4
    bands = inv_freq_bands(pos_dim, temperature=temperature, step=1, dtype=dtype, device=device)

    if reverse_coord:
        feat_shape = feat_shape[::-1]  # stack W, H instead of H, W
    grid = torch.stack(
        torch.meshgrid([torch.arange(s, device=device, dtype=dtype) for s in feat_shape])).flatten(1).transpose(0, 1)
    pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0)
    # FIXME add support for unflattened spatial dim?

    stack_dim = 2 if interleave_sin_cos else 1  # stack sin, cos, sin, cos instead of sin sin cos cos
    pos_emb = torch.stack([torch.sin(pos2), torch.cos(pos2)], dim=stack_dim).flatten(1)
    return pos_emb
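
# Usage sketch (illustrative, not part of the file above): for dim=64 and a 7x7
# feature map, the call below yields a (49, 64) table that can be added to a
# flattened token sequence:
#
#   pos_emb = build_sincos2d_pos_embed([7, 7], dim=64)  # (49, 64)
#   tokens = torch.randn(2, 49, 64)                     # (batch, H*W, dim)
#   tokens = tokens + pos_emb.unsqueeze(0)
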
def build_fourier_pos_embed(
        feat_shape: List[int],
        bands: Optional[torch.Tensor] = None,
        num_bands: int = 64,
        max_res: int = 224,
        linear_bands: bool = False,
        include_grid: bool = False,
        concat_out: bool = True,
        in_pixels: bool = True,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None,
) -> List[torch.Tensor]:
    if bands is None:
        if in_pixels:
            bands = pixel_freq_bands(num_bands, float(max_res), linear_bands=linear_bands, dtype=dtype, device=device)
        else:
            bands = inv_freq_bands(num_bands, step=1, dtype=dtype, device=device)
    else:
        if device is None:
            device = bands.device
        if dtype is None:
            dtype = bands.dtype

    if in_pixels:
        grid = torch.stack(torch.meshgrid(
            [torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in feat_shape]), dim=-1)
    else:
        grid = torch.stack(torch.meshgrid(
            [torch.arange(s, device=device, dtype=dtype) for s in feat_shape]), dim=-1)
    grid = grid.unsqueeze(-1)
    pos = grid * bands

    pos_sin, pos_cos = pos.sin(), pos.cos()
    out = (grid, pos_sin, pos_cos) if include_grid else (pos_sin, pos_cos)
    # FIXME torchscript doesn't like multiple return types, probably need to always cat?
    if concat_out:
        out = torch.cat(out, dim=-1)
    return out


class FourierEmbed(nn.Module):

    def __init__(self, max_res: int = 224, num_bands: int = 64, concat_grid=True, keep_spatial=False):
        super().__init__()
        self.max_res = max_res
        self.num_bands = num_bands
        self.concat_grid = concat_grid
        self.keep_spatial = keep_spatial
        self.register_buffer('bands', pixel_freq_bands(num_bands, float(max_res)), persistent=False)

    def forward(self, x):
        B, C = x.shape[:2]
        feat_shape = x.shape[2:]
        emb = build_fourier_pos_embed(
            feat_shape,
            self.bands,
            include_grid=self.concat_grid,
            dtype=x.dtype,
            device=x.device)
        emb = emb.transpose(-1, -2).flatten(len(feat_shape))
        batch_expand = (B,) + (-1,) * (x.ndim - 1)

        # FIXME support nD
        if self.keep_spatial:
            x = torch.cat([x, emb.unsqueeze(0).expand(batch_expand).permute(0, 3, 1, 2)], dim=1)
        else:
            x = torch.cat([x.permute(0, 2, 3, 1), emb.unsqueeze(0).expand(batch_expand)], dim=-1)
            x = x.reshape(B, feat_shape.numel(), -1)

        return x


def rot(x):
    return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)


def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb):
    return x * cos_emb + rot(x) * sin_emb


def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb):
    if isinstance(x, torch.Tensor):
        x = [x]
    return [t * cos_emb + rot(t) * sin_emb for t in x]


def apply_rot_embed_split(x: torch.Tensor, emb):
    split = emb.shape[-1] // 2
    return x * emb[:, :split] + rot(x) * emb[:, split:]


def build_rotary_pos_embed(
        feat_shape: List[int],
        bands: Optional[torch.Tensor] = None,
        dim: int = 64,
        max_freq: float = 224,
        linear_bands: bool = False,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None,
):
    """
    NOTE: shape arg should include spatial dim only
    """
    feat_shape = torch.Size(feat_shape)

    sin_emb, cos_emb = build_fourier_pos_embed(
        feat_shape, bands=bands, num_bands=dim // 4, max_res=max_freq, linear_bands=linear_bands,
        concat_out=False, device=device, dtype=dtype)
    N = feat_shape.numel()
    sin_emb = sin_emb.reshape(N, -1).repeat_interleave(2, -1)
    cos_emb = cos_emb.reshape(N, -1).repeat_interleave(2, -1)
    return sin_emb, cos_emb


class RotaryEmbedding(nn.Module):
    """ Rotary position embedding

    NOTE: This is my initial attempt at impl rotary embedding for spatial use, it has not
    been well tested, and will likely change. It will be moved to its own file.

    The following impl/resources were referenced for this impl:
    * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
    * https://blog.eleuther.ai/rotary-embeddings/
    """
    def __init__(self, dim, max_res=224, linear_bands: bool = False):
        super().__init__()
        self.dim = dim
        self.register_buffer('bands', pixel_freq_bands(dim // 4, max_res, linear_bands=linear_bands), persistent=False)

    def get_embed(self, shape: List[int]):
        return build_rotary_pos_embed(shape, self.bands)

    def forward(self, x):
        # assuming channel-first tensor where spatial dim are >= 2
        sin_emb, cos_emb = self.get_embed(x.shape[2:])
        return apply_rot_embed(x, sin_emb, cos_emb)
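
# Usage sketch (illustrative; per the NOTE above, this impl is not well tested):
# the (N, dim) sin/cos tables broadcast against a (batch, H*W, dim) token sequence:
#
#   rope = RotaryEmbedding(dim=32)
#   sin_emb, cos_emb = rope.get_embed([7, 7])  # each (49, 32)
#   tokens = torch.randn(2, 49, 32)            # (batch, H*W, dim)
#   tokens = apply_rot_embed(tokens, sin_emb, cos_emb)
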
@@ -0,0 +1,272 @@
""" MobileViT

Paper:
`MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer` - https://arxiv.org/abs/2110.02178

MobileVitBlock and checkpoints adapted from https://github.com/apple/ml-cvnets (original copyright below)
License: https://github.com/apple/ml-cvnets/blob/main/LICENSE (Apple open source)

Rest of code, ByobNet, and Transformer block hacked together by / Copyright 2022, Ross Wightman
"""
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2020 Apple Inc. All Rights Reserved.
#
import math
from typing import Union, Callable, Dict, Tuple, Optional

import torch
from torch import nn
import torch.nn.functional as F

from .byobnet import register_block, ByoBlockCfg, ByoModelCfg, ByobNet, LayerFn, num_groups
from .fx_features import register_notrace_module
from .layers import to_2tuple, make_divisible
from .vision_transformer import Block as TransformerBlock
from .helpers import build_model_with_cfg
from .registry import register_model

__all__ = []


def _cfg(url='', **kwargs):
    return {
        'url': url, 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
        'crop_pct': 0.9, 'interpolation': 'bicubic',
        'mean': (0, 0, 0), 'std': (1, 1, 1),
        'first_conv': 'stem.conv', 'classifier': 'head.fc',
        'fixed_input_size': False,
        **kwargs
    }


default_cfgs = {
    'mobilevit_xxs': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevit_xxs-ad385b40.pth'),
    'mobilevit_xs': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevit_xs-8fbd6366.pth'),
    'mobilevit_s': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevit_s-38a5a959.pth'),
    'semobilevit_s': _cfg(),
}


def _inverted_residual_block(d, c, s, br=4.0):
    # inverted residual is a bottleneck block with bottle_ratio > 1 applied to in_chs, linear output, gs=1 (depthwise)
    return ByoBlockCfg(
        type='bottle', d=d, c=c, s=s, gs=1, br=br,
        block_kwargs=dict(bottle_in=True, linear_out=True))


def _mobilevit_block(d, c, s, transformer_dim, transformer_depth, patch_size=4, br=4.0):
    # inverted residual + mobilevit blocks as per MobileViT network
    return (
        _inverted_residual_block(d=d, c=c, s=s, br=br),
        ByoBlockCfg(
            type='mobilevit', d=1, c=c, s=1,
            block_kwargs=dict(
                transformer_dim=transformer_dim,
                transformer_depth=transformer_depth,
                patch_size=patch_size)
        )
    )


model_cfgs = dict(
    mobilevit_xxs=ByoModelCfg(
        blocks=(
            _inverted_residual_block(d=1, c=16, s=1, br=2.0),
            _inverted_residual_block(d=3, c=24, s=2, br=2.0),
            _mobilevit_block(d=1, c=48, s=2, transformer_dim=64, transformer_depth=2, patch_size=2, br=2.0),
            _mobilevit_block(d=1, c=64, s=2, transformer_dim=80, transformer_depth=4, patch_size=2, br=2.0),
            _mobilevit_block(d=1, c=80, s=2, transformer_dim=96, transformer_depth=3, patch_size=2, br=2.0),
        ),
        stem_chs=16,
        stem_type='3x3',
        stem_pool='',
        downsample='',
        act_layer='silu',
        num_features=320,
    ),

    mobilevit_xs=ByoModelCfg(
        blocks=(
            _inverted_residual_block(d=1, c=32, s=1),
            _inverted_residual_block(d=3, c=48, s=2),
            _mobilevit_block(d=1, c=64, s=2, transformer_dim=96, transformer_depth=2, patch_size=2),
            _mobilevit_block(d=1, c=80, s=2, transformer_dim=120, transformer_depth=4, patch_size=2),
            _mobilevit_block(d=1, c=96, s=2, transformer_dim=144, transformer_depth=3, patch_size=2),
        ),
        stem_chs=16,
        stem_type='3x3',
        stem_pool='',
        downsample='',
        act_layer='silu',
        num_features=384,
    ),

    mobilevit_s=ByoModelCfg(
        blocks=(
            _inverted_residual_block(d=1, c=32, s=1),
            _inverted_residual_block(d=3, c=64, s=2),
            _mobilevit_block(d=1, c=96, s=2, transformer_dim=144, transformer_depth=2, patch_size=2),
            _mobilevit_block(d=1, c=128, s=2, transformer_dim=192, transformer_depth=4, patch_size=2),
            _mobilevit_block(d=1, c=160, s=2, transformer_dim=240, transformer_depth=3, patch_size=2),
        ),
        stem_chs=16,
        stem_type='3x3',
        stem_pool='',
        downsample='',
        act_layer='silu',
        num_features=640,
    ),

    semobilevit_s=ByoModelCfg(
        blocks=(
            _inverted_residual_block(d=1, c=32, s=1),
            _inverted_residual_block(d=3, c=64, s=2),
            _mobilevit_block(d=1, c=96, s=2, transformer_dim=144, transformer_depth=2, patch_size=2),
            _mobilevit_block(d=1, c=128, s=2, transformer_dim=192, transformer_depth=4, patch_size=2),
            _mobilevit_block(d=1, c=160, s=2, transformer_dim=240, transformer_depth=3, patch_size=2),
        ),
        stem_chs=16,
        stem_type='3x3',
        stem_pool='',
        downsample='',
        attn_layer='se',
        attn_kwargs=dict(rd_ratio=1/8),
        num_features=640,
    ),
)


@register_notrace_module
class MobileViTBlock(nn.Module):
    """ MobileViT block
    Paper: https://arxiv.org/abs/2110.02178?context=cs.LG
    """
    def __init__(
            self,
            in_chs: int,
            out_chs: Optional[int] = None,
            kernel_size: int = 3,
            stride: int = 1,
            bottle_ratio: float = 1.0,
            group_size: Optional[int] = None,
            dilation: Tuple[int, int] = (1, 1),
            mlp_ratio: float = 2.0,
            transformer_dim: Optional[int] = None,
            transformer_depth: int = 2,
            patch_size: int = 8,
            num_heads: int = 4,
            attn_drop: float = 0.,
            drop: float = 0.,
            no_fusion: bool = False,
            drop_path_rate: float = 0.,
            layers: LayerFn = None,
            transformer_norm_layer: Callable = nn.LayerNorm,
            downsample: str = ''
    ):
        super(MobileViTBlock, self).__init__()

        layers = layers or LayerFn()
        groups = num_groups(group_size, in_chs)
        out_chs = out_chs or in_chs
        transformer_dim = transformer_dim or make_divisible(bottle_ratio * in_chs)

        self.conv_kxk = layers.conv_norm_act(
            in_chs, in_chs, kernel_size=kernel_size,
            stride=stride, groups=groups, dilation=dilation[0])
        self.conv_1x1 = nn.Conv2d(in_chs, transformer_dim, kernel_size=1, bias=False)

        self.transformer = nn.Sequential(*[
            TransformerBlock(
                transformer_dim, mlp_ratio=mlp_ratio, num_heads=num_heads, qkv_bias=True,
                attn_drop=attn_drop, drop=drop, drop_path=drop_path_rate,
                act_layer=layers.act, norm_layer=transformer_norm_layer)
            for _ in range(transformer_depth)
        ])
        self.norm = transformer_norm_layer(transformer_dim)

        self.conv_proj = layers.conv_norm_act(transformer_dim, out_chs, kernel_size=1, stride=1)

        if no_fusion:
            self.conv_fusion = None
        else:
            self.conv_fusion = layers.conv_norm_act(in_chs + out_chs, out_chs, kernel_size=kernel_size, stride=1)

        self.patch_size = to_2tuple(patch_size)
        self.patch_area = self.patch_size[0] * self.patch_size[1]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x

        # Local representation
        x = self.conv_kxk(x)
        x = self.conv_1x1(x)

        # Unfold (feature map -> patches)
        patch_h, patch_w = self.patch_size
        B, C, H, W = x.shape
        new_h, new_w = math.ceil(H / patch_h) * patch_h, math.ceil(W / patch_w) * patch_w
        num_patch_h, num_patch_w = new_h // patch_h, new_w // patch_w  # n_h, n_w
        num_patches = num_patch_h * num_patch_w  # N
        interpolate = False
        if new_h != H or new_w != W:
            # Note: Padding can be done, but then it needs to be handled in attention function.
            x = F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=False)
            interpolate = True

        # [B, C, H, W] --> [B * C * n_h, n_w, p_h, p_w]
        x = x.reshape(B * C * num_patch_h, patch_h, num_patch_w, patch_w).transpose(1, 2)
        # [B * C * n_h, n_w, p_h, p_w] --> [BP, N, C] where P = p_h * p_w and N = n_h * n_w
        x = x.reshape(B, C, num_patches, self.patch_area).transpose(1, 3).reshape(B * self.patch_area, num_patches, -1)

        # Global representations
        x = self.transformer(x)
        x = self.norm(x)

        # Fold (patch -> feature map)
        # [B, P, N, C] --> [B*C*n_h, n_w, p_h, p_w]
        x = x.contiguous().view(B, self.patch_area, num_patches, -1)
        x = x.transpose(1, 3).reshape(B * C * num_patch_h, num_patch_w, patch_h, patch_w)
        # [B*C*n_h, n_w, p_h, p_w] --> [B*C*n_h, p_h, n_w, p_w] --> [B, C, H, W]
        x = x.transpose(1, 2).reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)
        if interpolate:
            x = F.interpolate(x, size=(H, W), mode="bilinear", align_corners=False)

        x = self.conv_proj(x)
        if self.conv_fusion is not None:
            x = self.conv_fusion(torch.cat((shortcut, x), dim=1))
        return x
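
# Shape walk-through of the unfold/fold in MobileViTBlock.forward (illustrative
# values: B=1, C=96, H=W=32, patch_size=2, so n_h=n_w=16, N=256, P=4):
#   (1, 96, 32, 32) -> (1536, 2, 16, 2) -> transpose -> (1536, 16, 2, 2)
#   -> (1, 96, 256, 4) -> transpose -> (1, 4, 256, 96) -> (4, 256, 96)  # B*P, N, C
# i.e. the transformer runs over P interleaved sequences of N "patch pixels" each,
# and the fold path inverts the same steps back to (1, 96, 32, 32).
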
register_block('mobilevit', MobileViTBlock)


def _create_mobilevit(variant, cfg_variant=None, pretrained=False, **kwargs):
    return build_model_with_cfg(
        ByobNet, variant, pretrained,
        model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
        feature_cfg=dict(flatten_sequential=True),
        **kwargs)


@register_model
def mobilevit_xxs(pretrained=False, **kwargs):
    return _create_mobilevit('mobilevit_xxs', pretrained=pretrained, **kwargs)


@register_model
def mobilevit_xs(pretrained=False, **kwargs):
    return _create_mobilevit('mobilevit_xs', pretrained=pretrained, **kwargs)


@register_model
def mobilevit_s(pretrained=False, **kwargs):
    return _create_mobilevit('mobilevit_s', pretrained=pretrained, **kwargs)


@register_model
def semobilevit_s(pretrained=False, **kwargs):
    return _create_mobilevit('semobilevit_s', pretrained=pretrained, **kwargs)
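
# Usage sketch (illustrative; assumes this module is imported as part of timm so
# the @register_model decorators above have run):
#
#   import timm, torch
#   model = timm.create_model('mobilevit_s', pretrained=True)
#   logits = model(torch.randn(1, 3, 256, 256))  # default cfg input_size is (3, 256, 256)
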
@@ -0,0 +1,322 @@
""" PoolFormer implementation

Paper: `PoolFormer: MetaFormer is Actually What You Need for Vision` - https://arxiv.org/abs/2111.11418

Code adapted from official impl at https://github.com/sail-sg/poolformer, original copyright in comment below

Modifications and additions for timm by / Copyright 2022, Ross Wightman
"""
# Copyright 2021 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import copy
import torch
import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, checkpoint_seq
from .layers import DropPath, trunc_normal_, to_2tuple, ConvMlp
from .registry import register_model


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .95, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.proj', 'classifier': 'head',
        **kwargs
    }


default_cfgs = dict(
    poolformer_s12=_cfg(
        url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s12.pth.tar',
        crop_pct=0.9),
    poolformer_s24=_cfg(
        url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s24.pth.tar',
        crop_pct=0.9),
    poolformer_s36=_cfg(
        url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s36.pth.tar',
        crop_pct=0.9),
    poolformer_m36=_cfg(
        url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_m36.pth.tar',
        crop_pct=0.95),
    poolformer_m48=_cfg(
        url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_m48.pth.tar',
        crop_pct=0.95),
)


class PatchEmbed(nn.Module):
    """ Patch Embedding that is implemented by a layer of conv.
    Input: tensor in shape [B, C, H, W]
    Output: tensor in shape [B, C, H/stride, W/stride]
    """

    def __init__(self, in_chs=3, embed_dim=768, patch_size=16, stride=16, padding=0, norm_layer=None):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        stride = to_2tuple(stride)
        padding = to_2tuple(padding)
        self.proj = nn.Conv2d(in_chs, embed_dim, kernel_size=patch_size, stride=stride, padding=padding)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        x = self.proj(x)
        x = self.norm(x)
        return x


class GroupNorm1(nn.GroupNorm):
    """ Group Normalization with 1 group.
    Input: tensor in shape [B, C, H, W]
    """

    def __init__(self, num_channels, **kwargs):
        super().__init__(1, num_channels, **kwargs)


class Pooling(nn.Module):
    def __init__(self, pool_size=3):
        super().__init__()
        self.pool = nn.AvgPool2d(pool_size, stride=1, padding=pool_size // 2, count_include_pad=False)

    def forward(self, x):
        return self.pool(x) - x


class PoolFormerBlock(nn.Module):
    """
    Args:
        dim: embedding dim
        pool_size: pooling size
        mlp_ratio: mlp expansion ratio
        act_layer: activation
        norm_layer: normalization
        drop: dropout rate
        drop_path: stochastic depth rate, refer to https://arxiv.org/abs/1603.09382
        layer_scale_init_value: LayerScale init value, refer to https://arxiv.org/abs/2103.17239
    """

    def __init__(
            self, dim, pool_size=3, mlp_ratio=4.,
            act_layer=nn.GELU, norm_layer=GroupNorm1,
            drop=0., drop_path=0., layer_scale_init_value=1e-5):

        super().__init__()

        self.norm1 = norm_layer(dim)
        self.token_mixer = Pooling(pool_size=pool_size)
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = ConvMlp(dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        if layer_scale_init_value:
            self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
        else:
            self.layer_scale_1 = None
            self.layer_scale_2 = None

    def forward(self, x):
        if self.layer_scale_1 is not None:
            x = x + self.drop_path1(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.token_mixer(self.norm1(x)))
            x = x + self.drop_path2(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path1(self.token_mixer(self.norm1(x)))
            x = x + self.drop_path2(self.mlp(self.norm2(x)))
        return x


def basic_blocks(
        dim, index, layers,
        pool_size=3, mlp_ratio=4.,
        act_layer=nn.GELU, norm_layer=GroupNorm1,
        drop_rate=0., drop_path_rate=0.,
        layer_scale_init_value=1e-5,
):
    """ generate PoolFormer blocks for a stage """
    blocks = []
    for block_idx in range(layers[index]):
        block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
        blocks.append(PoolFormerBlock(
            dim, pool_size=pool_size, mlp_ratio=mlp_ratio,
            act_layer=act_layer, norm_layer=norm_layer,
            drop=drop_rate, drop_path=block_dpr,
            layer_scale_init_value=layer_scale_init_value,
        ))
    blocks = nn.Sequential(*blocks)
    return blocks


class PoolFormer(nn.Module):
    """ PoolFormer
    """

    def __init__(
            self,
            layers,
            embed_dims=(64, 128, 320, 512),
            mlp_ratios=(4, 4, 4, 4),
            downsamples=(True, True, True, True),
            pool_size=3,
            in_chans=3,
            num_classes=1000,
            global_pool='avg',
            norm_layer=GroupNorm1,
            act_layer=nn.GELU,
            in_patch_size=7,
            in_stride=4,
            in_pad=2,
            down_patch_size=3,
            down_stride=2,
            down_pad=1,
            drop_rate=0., drop_path_rate=0.,
            layer_scale_init_value=1e-5,
            **kwargs):

        super().__init__()
        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = embed_dims[-1]
        self.grad_checkpointing = False

        self.patch_embed = PatchEmbed(
            patch_size=in_patch_size, stride=in_stride, padding=in_pad,
            in_chs=in_chans, embed_dim=embed_dims[0])

        # set the main block in network
        network = []
        for i in range(len(layers)):
            network.append(basic_blocks(
                embed_dims[i], i, layers,
                pool_size=pool_size, mlp_ratio=mlp_ratios[i],
                act_layer=act_layer, norm_layer=norm_layer,
                drop_rate=drop_rate, drop_path_rate=drop_path_rate,
                layer_scale_init_value=layer_scale_init_value)
            )
            if i < len(layers) - 1 and (downsamples[i] or embed_dims[i] != embed_dims[i + 1]):
                # downsampling between stages
                network.append(PatchEmbed(
                    in_chs=embed_dims[i], embed_dim=embed_dims[i + 1],
                    patch_size=down_patch_size, stride=down_stride, padding=down_pad)
                )

        self.network = nn.Sequential(*network)
        self.norm = norm_layer(self.num_features)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    # init for classification
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        return dict(
            stem=r'^patch_embed',  # stem and embed
            blocks=[
                (r'^network\.(\d+)\.(\d+)', None),
                (r'^network\.(\d+)', (0,)),
                (r'^norm', (99999,))
            ],
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=None):
        self.num_classes = num_classes
        if global_pool is not None:
            self.global_pool = global_pool
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.patch_embed(x)
        x = self.network(x)
        x = self.norm(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        if self.global_pool == 'avg':
            x = x.mean([-2, -1])
        return x if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x


def _create_poolformer(variant, pretrained=False, **kwargs):
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for PoolFormer models.')
    model = build_model_with_cfg(PoolFormer, variant, pretrained, **kwargs)
    return model


@register_model
def poolformer_s12(pretrained=False, **kwargs):
    """ PoolFormer-S12 model, Params: 12M """
    model = _create_poolformer('poolformer_s12', pretrained=pretrained, layers=(2, 2, 6, 2), **kwargs)
    return model


@register_model
def poolformer_s24(pretrained=False, **kwargs):
    """ PoolFormer-S24 model, Params: 21M """
    model = _create_poolformer('poolformer_s24', pretrained=pretrained, layers=(4, 4, 12, 4), **kwargs)
    return model


@register_model
def poolformer_s36(pretrained=False, **kwargs):
    """ PoolFormer-S36 model, Params: 31M """
    model = _create_poolformer(
        'poolformer_s36', pretrained=pretrained, layers=(6, 6, 18, 6), layer_scale_init_value=1e-6, **kwargs)
    return model


@register_model
def poolformer_m36(pretrained=False, **kwargs):
    """ PoolFormer-M36 model, Params: 56M """
    layers = (6, 6, 18, 6)
    embed_dims = (96, 192, 384, 768)
    model = _create_poolformer(
        'poolformer_m36', pretrained=pretrained, layers=layers, embed_dims=embed_dims,
        layer_scale_init_value=1e-6, **kwargs)
    return model


@register_model
def poolformer_m48(pretrained=False, **kwargs):
    """ PoolFormer-M48 model, Params: 73M """
    layers = (8, 8, 24, 8)
    embed_dims = (96, 192, 384, 768)
    model = _create_poolformer(
        'poolformer_m48', pretrained=pretrained, layers=layers, embed_dims=embed_dims,
        layer_scale_init_value=1e-6, **kwargs)
    return model
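
# Usage sketch (illustrative): the Pooling token mixer computes pool(x) - x, replacing
# each position with the difference from its local mean, so the residual
# x + token_mixer(norm(x)) in PoolFormerBlock does not simply double x.
#
#   mixer = Pooling(pool_size=3)
#   mixed = mixer(torch.randn(1, 64, 14, 14))    # same shape as the input
#   model = poolformer_s12()                     # or timm.create_model('poolformer_s12')
#   logits = model(torch.randn(1, 3, 224, 224))  # (1, 1000)
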
@@ -0,0 +1,750 @@
""" Vision OutLOoker (VOLO) implementation

Paper: `VOLO: Vision Outlooker for Visual Recognition` - https://arxiv.org/abs/2106.13112

Code adapted from official impl at https://github.com/sail-sg/volo, original copyright in comment below

Modifications and additions for timm by / Copyright 2022, Ross Wightman
"""
# Copyright 2021 Sea Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import DropPath, Mlp, to_2tuple, to_ntuple, trunc_normal_
from timm.models.registry import register_model
from timm.models.helpers import build_model_with_cfg


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .96, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.conv.0', 'classifier': ('head', 'aux_head'),
        **kwargs
    }


default_cfgs = {
    'volo_d1_224': _cfg(
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d1_224_84.2.pth.tar',
        crop_pct=0.96),
    'volo_d1_384': _cfg(
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d1_384_85.2.pth.tar',
        crop_pct=1.0, input_size=(3, 384, 384)),
    'volo_d2_224': _cfg(
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d2_224_85.2.pth.tar',
        crop_pct=0.96),
    'volo_d2_384': _cfg(
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d2_384_86.0.pth.tar',
        crop_pct=1.0, input_size=(3, 384, 384)),
    'volo_d3_224': _cfg(
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d3_224_85.4.pth.tar',
        crop_pct=0.96),
    'volo_d3_448': _cfg(
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d3_448_86.3.pth.tar',
        crop_pct=1.0, input_size=(3, 448, 448)),
    'volo_d4_224': _cfg(
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d4_224_85.7.pth.tar',
        crop_pct=0.96),
    'volo_d4_448': _cfg(
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d4_448_86.79.pth.tar',
        crop_pct=1.15, input_size=(3, 448, 448)),
    'volo_d5_224': _cfg(
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d5_224_86.10.pth.tar',
        crop_pct=0.96),
    'volo_d5_448': _cfg(
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d5_448_87.0.pth.tar',
        crop_pct=1.15, input_size=(3, 448, 448)),
    'volo_d5_512': _cfg(
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d5_512_87.07.pth.tar',
        crop_pct=1.15, input_size=(3, 512, 512)),
}


class OutlookAttention(nn.Module):

    def __init__(self, dim, num_heads, kernel_size=3, padding=1, stride=1, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        head_dim = dim // num_heads
        self.num_heads = num_heads
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.scale = head_dim ** -0.5

        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn = nn.Linear(dim, kernel_size ** 4 * num_heads)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.unfold = nn.Unfold(kernel_size=kernel_size, padding=padding, stride=stride)
        self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True)

    def forward(self, x):
        B, H, W, C = x.shape

        v = self.v(x).permute(0, 3, 1, 2)  # B, C, H, W

        h, w = math.ceil(H / self.stride), math.ceil(W / self.stride)
        v = self.unfold(v).reshape(
            B, self.num_heads, C // self.num_heads,
            self.kernel_size * self.kernel_size, h * w).permute(0, 1, 4, 3, 2)  # B,H,N,kxk,C/H

        attn = self.pool(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        attn = self.attn(attn).reshape(
            B, h * w, self.num_heads, self.kernel_size * self.kernel_size,
            self.kernel_size * self.kernel_size).permute(0, 2, 1, 3, 4)  # B,H,N,kxk,kxk
        attn = attn * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).permute(0, 1, 4, 3, 2).reshape(B, C * self.kernel_size * self.kernel_size, h * w)
        x = F.fold(x, output_size=(H, W), kernel_size=self.kernel_size, padding=self.padding, stride=self.stride)

        x = self.proj(x.permute(0, 2, 3, 1))
        x = self.proj_drop(x)

        return x
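
# Shape sketch for OutlookAttention.forward (illustrative values: B=1, H=W=28, C=192,
# num_heads=6, kernel_size=3, stride=2, so h=w=14, N=h*w=196, C/head=32):
#   v unfolded: (1, 192*9, 196) -> reshape/permute -> (1, 6, 196, 9, 32)
#   attn: pooled (1, 14, 14, 192) -> linear to k**4 * heads -> (1, 6, 196, 9, 9)
# Each coarse position predicts a dense 9x9 attention map directly (no query-key
# dot product); attn @ v aggregates, then F.fold scatters the overlapping 3x3
# patches back to (1, 192, 28, 28).
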
class Outlooker(nn.Module):
    def __init__(
            self, dim, kernel_size, padding, stride=1, num_heads=1, mlp_ratio=3., attn_drop=0.,
            drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, qkv_bias=False
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = OutlookAttention(
            dim, num_heads, kernel_size=kernel_size,
            padding=padding, stride=stride,
            qkv_bias=qkv_bias, attn_drop=attn_drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class Attention(nn.Module):

    def __init__(
            self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, H, W, C = x.shape

        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, H, W, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x


class Transformer(nn.Module):

    def __init__(
            self, dim, num_heads, mlp_ratio=4., qkv_bias=False,
            attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop)

        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class ClassAttention(nn.Module):

    def __init__(
            self, dim, num_heads=8, head_dim=None, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        if head_dim is not None:
            self.head_dim = head_dim
        else:
            head_dim = dim // num_heads
            self.head_dim = head_dim
        self.scale = head_dim ** -0.5

        self.kv = nn.Linear(dim, self.head_dim * self.num_heads * 2, bias=qkv_bias)
        self.q = nn.Linear(dim, self.head_dim * self.num_heads, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(self.head_dim * self.num_heads, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape

        kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)
        q = self.q(x[:, :1, :]).reshape(B, self.num_heads, 1, self.head_dim)
        attn = ((q * self.scale) @ k.transpose(-2, -1))
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        cls_embed = (attn @ v).transpose(1, 2).reshape(B, 1, self.head_dim * self.num_heads)
        cls_embed = self.proj(cls_embed)
        cls_embed = self.proj_drop(cls_embed)
        return cls_embed


class ClassBlock(nn.Module):

    def __init__(
            self, dim, num_heads, head_dim=None, mlp_ratio=4., qkv_bias=False,
            drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = ClassAttention(
            dim, num_heads=num_heads, head_dim=head_dim, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        cls_embed = x[:, :1]
        cls_embed = cls_embed + self.drop_path(self.attn(self.norm1(x)))
        cls_embed = cls_embed + self.drop_path(self.mlp(self.norm2(cls_embed)))
        return torch.cat([cls_embed, x[:, 1:]], dim=1)


def get_block(block_type, **kwargs):
    if block_type == 'ca':
        return ClassBlock(**kwargs)
    raise ValueError(f'Unknown post block type ({block_type}).')


def rand_bbox(size, lam, scale=1):
    """
    get bounding box for token labeling (https://github.com/zihangJiang/TokenLabeling)
    return: bounding box
    """
    W = size[1] // scale
    H = size[2] // scale
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding.
    Unlike ViT, which uses a single conv layer, a multi-layer conv stem is used for patch embedding
    """

    def __init__(
            self, img_size=224, stem_conv=False, stem_stride=1,
            patch_size=8, in_chans=3, hidden_dim=64, embed_dim=384):
        super().__init__()
        assert patch_size in [4, 8, 16]
        if stem_conv:
            self.conv = nn.Sequential(
                nn.Conv2d(in_chans, hidden_dim, kernel_size=7, stride=stem_stride, padding=3, bias=False),  # 112x112
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False),  # 112x112
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False),  # 112x112
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
            )
        else:
            self.conv = None

        self.proj = nn.Conv2d(
            hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride)
        self.num_patches = (img_size // patch_size) * (img_size // patch_size)

    def forward(self, x):
        if self.conv is not None:
            x = self.conv(x)
        x = self.proj(x)  # B, C, H, W
        return x


class Downsample(nn.Module):
    """ Image to Patch Embedding, downsampling between stage1 and stage2
    """

    def __init__(self, in_embed_dim, out_embed_dim, patch_size=2):
        super().__init__()
        self.proj = nn.Conv2d(in_embed_dim, out_embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = x.permute(0, 3, 1, 2)
        x = self.proj(x)  # B, C, H, W
        x = x.permute(0, 2, 3, 1)
        return x


def outlooker_blocks(
        block_fn, index, dim, layers, num_heads=1, kernel_size=3, padding=1, stride=2,
        mlp_ratio=3., qkv_bias=False, attn_drop=0, drop_path_rate=0., **kwargs):
    """
    generate outlooker layers for stage 1
    return: outlooker layers
    """
    blocks = []
    for block_idx in range(layers[index]):
        block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
        blocks.append(
            block_fn(
                dim, kernel_size=kernel_size, padding=padding,
                stride=stride, num_heads=num_heads, mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, attn_drop=attn_drop, drop_path=block_dpr))
    blocks = nn.Sequential(*blocks)
    return blocks


def transformer_blocks(
        block_fn, index, dim, layers, num_heads, mlp_ratio=3.,
        qkv_bias=False, attn_drop=0, drop_path_rate=0., **kwargs):
    """
    generate transformer layers for stage 2
    return: transformer layers
    """
    blocks = []
    for block_idx in range(layers[index]):
        block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
        blocks.append(
            block_fn(
                dim, num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                attn_drop=attn_drop,
                drop_path=block_dpr))
    blocks = nn.Sequential(*blocks)
    return blocks


class VOLO(nn.Module):
    """
    Vision Outlooker, the main class of our model
    """

    def __init__(
            self,
            layers,
            img_size=224,
            in_chans=3,
            num_classes=1000,
            global_pool='token',
            patch_size=8,
            stem_hidden_dim=64,
            embed_dims=None,
            num_heads=None,
            downsamples=(True, False, False, False),
            outlook_attention=(True, False, False, False),
            mlp_ratio=3.0,
            qkv_bias=False,
            drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.,
            norm_layer=nn.LayerNorm,
            post_layers=('ca', 'ca'),
            use_aux_head=True,
            use_mix_token=False,
            pooling_scale=2,
    ):
        super().__init__()
        num_layers = len(layers)
        mlp_ratio = to_ntuple(num_layers)(mlp_ratio)
        img_size = to_2tuple(img_size)

        self.num_classes = num_classes
        self.global_pool = global_pool
        self.mix_token = use_mix_token
        self.pooling_scale = pooling_scale
        self.num_features = embed_dims[-1]
        if use_mix_token:  # enable token mixing, see token labeling for details.
            self.beta = 1.0
            assert global_pool == 'token', "return all tokens if mix_token is enabled"
        self.grad_checkpointing = False

        self.patch_embed = PatchEmbed(
            stem_conv=True, stem_stride=2, patch_size=patch_size,
            in_chans=in_chans, hidden_dim=stem_hidden_dim,
            embed_dim=embed_dims[0])

        # initial positional encoding, we add positional encoding after outlooker blocks
        patch_grid = (img_size[0] // patch_size // pooling_scale, img_size[1] // patch_size // pooling_scale)
        self.pos_embed = nn.Parameter(torch.zeros(1, patch_grid[0], patch_grid[1], embed_dims[-1]))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # set the main block in network
        network = []
        for i in range(len(layers)):
            if outlook_attention[i]:
                # stage 1
                stage = outlooker_blocks(
                    Outlooker, i, embed_dims[i], layers, num_heads[i], mlp_ratio=mlp_ratio[i],
                    qkv_bias=qkv_bias, attn_drop=attn_drop_rate, norm_layer=norm_layer)
                network.append(stage)
            else:
                # stage 2
                stage = transformer_blocks(
                    Transformer, i, embed_dims[i], layers, num_heads[i], mlp_ratio=mlp_ratio[i], qkv_bias=qkv_bias,
                    drop_path_rate=drop_path_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer)
                network.append(stage)

            if downsamples[i]:
                # downsampling between two stages
                network.append(Downsample(embed_dims[i], embed_dims[i + 1], 2))

        self.network = nn.ModuleList(network)

        # set post block, for example, class attention layers
        self.post_network = None
        if post_layers is not None:
            self.post_network = nn.ModuleList(
                [
                    get_block(
                        post_layers[i],
                        dim=embed_dims[-1],
                        num_heads=num_heads[-1],
                        mlp_ratio=mlp_ratio[-1],
                        qkv_bias=qkv_bias,
                        attn_drop=attn_drop_rate,
                        drop_path=0.,
                        norm_layer=norm_layer)
                    for i in range(len(post_layers))
                ])
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims[-1]))
            trunc_normal_(self.cls_token, std=.02)

        # set output type
        if use_aux_head:
            self.aux_head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        else:
            self.aux_head = None
        self.norm = norm_layer(self.num_features)

        # Classifier head
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        return dict(
            stem=r'^cls_token|pos_embed|patch_embed',  # stem and embed
            blocks=[
                (r'^network\.(\d+)\.(\d+)', None),
                (r'^network\.(\d+)', (0,)),
            ],
            blocks2=[
                (r'^cls_token', (0,)),
                (r'^post_network\.(\d+)', None),
                (r'^norm', (99999,))
            ],
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=None):
        self.num_classes = num_classes
        if global_pool is not None:
            self.global_pool = global_pool
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        if self.aux_head is not None:
            self.aux_head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_tokens(self, x):
        for idx, block in enumerate(self.network):
            if idx == 2:
                # add positional encoding after outlooker blocks
                x = x + self.pos_embed
                x = self.pos_drop(x)
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(block, x)
            else:
                x = block(x)

        B, H, W, C = x.shape
        x = x.reshape(B, -1, C)
        return x

    def forward_cls(self, x):
        B, N, C = x.shape
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        for block in self.post_network:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(block, x)
            else:
                x = block(x)
        return x

    def forward_train(self, x):
        """ A separate forward fn for training w/ mix_token (if a train script supports it).
        Combining multiple modes in a single forward with different return types is torchscript hell.
        """
        x = self.patch_embed(x)
        x = x.permute(0, 2, 3, 1)  # B,C,H,W-> B,H,W,C

        # mix token, see token labeling for details.
        if self.mix_token and self.training:
            lam = np.random.beta(self.beta, self.beta)
            patch_h, patch_w = x.shape[1] // self.pooling_scale, x.shape[2] // self.pooling_scale
            bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam, scale=self.pooling_scale)
            temp_x = x.clone()
            sbbx1, sbby1 = self.pooling_scale * bbx1, self.pooling_scale * bby1
            sbbx2, sbby2 = self.pooling_scale * bbx2, self.pooling_scale * bby2
            temp_x[:, sbbx1:sbbx2, sbby1:sbby2, :] = x.flip(0)[:, sbbx1:sbbx2, sbby1:sbby2, :]
            x = temp_x
        else:
            bbx1, bby1, bbx2, bby2 = 0, 0, 0, 0

        # step2: tokens learning in the two stages
        x = self.forward_tokens(x)

        # step3: post network, apply class attention or not
        if self.post_network is not None:
            x = self.forward_cls(x)
        x = self.norm(x)

        if self.global_pool == 'avg':
            x_cls = x.mean(dim=1)
        elif self.global_pool == 'token':
            x_cls = x[:, 0]
        else:
            x_cls = x

        if self.aux_head is None:
            return x_cls

        x_aux = self.aux_head(x[:, 1:])  # generate classes in all feature tokens, see token labeling
        if not self.training:
            return x_cls + 0.5 * x_aux.max(1)[0]

        if self.mix_token and self.training:  # reverse "mix token", see token labeling for details.
            x_aux = x_aux.reshape(x_aux.shape[0], patch_h, patch_w, x_aux.shape[-1])
            temp_x = x_aux.clone()
            temp_x[:, bbx1:bbx2, bby1:bby2, :] = x_aux.flip(0)[:, bbx1:bbx2, bby1:bby2, :]
            x_aux = temp_x
            x_aux = x_aux.reshape(x_aux.shape[0], patch_h * patch_w, x_aux.shape[-1])

        # return these: 1. class token, 2. classes from all feature tokens, 3. bounding box
        return x_cls, x_aux, (bbx1, bby1, bbx2, bby2)

    def forward_features(self, x):
        x = self.patch_embed(x).permute(0, 2, 3, 1)  # B,C,H,W-> B,H,W,C

        # step2: tokens learning in the two stages
        x = self.forward_tokens(x)

        # step3: post network, apply class attention or not
        if self.post_network is not None:
            x = self.forward_cls(x)
        x = self.norm(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        if self.global_pool == 'avg':
            out = x.mean(dim=1)
        elif self.global_pool == 'token':
            out = x[:, 0]
        else:
            out = x
        if pre_logits:
            return out
        out = self.head(out)
        if self.aux_head is not None:
            # generate classes in all feature tokens, see token labeling
            aux = self.aux_head(x[:, 1:])
            out = out + 0.5 * aux.max(1)[0]
        return out

    def forward(self, x):
        """ simplified forward (without mix token training) """
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x


def _create_volo(variant, pretrained=False, **kwargs):
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for VOLO models.')
    return build_model_with_cfg(VOLO, variant, pretrained, **kwargs)


@register_model
def volo_d1_224(pretrained=False, **kwargs):
    """ VOLO-D1 model, Params: 27M """
    model_args = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12), **kwargs)
    model = _create_volo('volo_d1_224', pretrained=pretrained, **model_args)
    return model


@register_model
def volo_d1_384(pretrained=False, **kwargs):
    """ VOLO-D1 model, Params: 27M """
    model_args = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12), **kwargs)
    model = _create_volo('volo_d1_384', pretrained=pretrained, **model_args)
    return model


@register_model
def volo_d2_224(pretrained=False, **kwargs):
    """ VOLO-D2 model, Params: 59M """
    model_args = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
    model = _create_volo('volo_d2_224', pretrained=pretrained, **model_args)
    return model


@register_model
def volo_d2_384(pretrained=False, **kwargs):
    """ VOLO-D2 model, Params: 59M """
    model_args = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
    model = _create_volo('volo_d2_384', pretrained=pretrained, **model_args)
    return model


@register_model
def volo_d3_224(pretrained=False, **kwargs):
    """ VOLO-D3 model, Params: 86M """
    model_args = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
    model = _create_volo('volo_d3_224', pretrained=pretrained, **model_args)
    return model


@register_model
def volo_d3_448(pretrained=False, **kwargs):
    """ VOLO-D3 model, Params: 86M """
    model_args = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
    model = _create_volo('volo_d3_448', pretrained=pretrained, **model_args)
    return model


@register_model
def volo_d4_224(pretrained=False, **kwargs):
    """ VOLO-D4 model, Params: 193M """
    model_args = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), **kwargs)
    model = _create_volo('volo_d4_224', pretrained=pretrained, **model_args)
    return model


@register_model
def volo_d4_448(pretrained=False, **kwargs):
    """ VOLO-D4 model, Params: 193M """
    model_args = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), **kwargs)
    model = _create_volo('volo_d4_448', pretrained=pretrained, **model_args)
    return model


@register_model
def volo_d5_224(pretrained=False, **kwargs):
    """ VOLO-D5 model, Params: 296M
    stem_hidden_dim=128, the dim in patch embedding is 128 for VOLO-D5
    """
    model_args = dict(
        layers=(12, 12, 20, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16),
        mlp_ratio=4, stem_hidden_dim=128, **kwargs)
    model = _create_volo('volo_d5_224', pretrained=pretrained, **model_args)
    return model


@register_model
def volo_d5_448(pretrained=False, **kwargs):
    """ VOLO-D5 model, Params: 296M
    stem_hidden_dim=128, the dim in patch embedding is 128 for VOLO-D5
    """
    model_args = dict(
        layers=(12, 12, 20, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16),
        mlp_ratio=4, stem_hidden_dim=128, **kwargs)
    model = _create_volo('volo_d5_448', pretrained=pretrained, **model_args)
    return model


@register_model
def volo_d5_512(pretrained=False, **kwargs):
    """ VOLO-D5 model, Params: 296M
    stem_hidden_dim=128, the dim in patch embedding is 128 for VOLO-D5
    """
    model_args = dict(
        layers=(12, 12, 20, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16),
        mlp_ratio=4, stem_hidden_dim=128, **kwargs)
    model = _create_volo('volo_d5_512', pretrained=pretrained, **model_args)
    return model
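
# Usage sketch (illustrative; assumes this module is imported as part of timm so
# the @register_model decorators above have run):
#
#   model = volo_d1_224()                         # or timm.create_model('volo_d1_224')
#   logits = model(torch.randn(1, 3, 224, 224))   # (1, 1000); input size is fixed
#   # with use_mix_token=True, forward_train() instead returns (x_cls, x_aux, bbox)
#   # for token-labeling style training losses.
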