Add MobileVitV2 support. Fix #1332. Move GroupNorm1 to common layers (used in poolformer + mobilevitv2). Keep ol custom ConvNeXt LayerNorm2d impl as LayerNormExp2d for reference.

pull/1327/head
Ross Wightman 3 years ago
parent 06307b8b41
commit eca09b8642

@ -25,7 +25,7 @@ from .linear import Linear
from .mixed_conv2d import MixedConv2d from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
from .non_local_attn import NonLocalAttn, BatNonLocalAttn from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import GroupNorm, LayerNorm2d from .norm import GroupNorm, GroupNorm1, LayerNorm2d
from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
from .padding import get_padding, get_same_padding, pad_same from .padding import get_padding, get_same_padding, pad_same
from .patch_embed import PatchEmbed from .patch_embed import PatchEmbed

@ -22,7 +22,7 @@ def get_attn(attn_type):
if isinstance(attn_type, torch.nn.Module): if isinstance(attn_type, torch.nn.Module):
return attn_type return attn_type
module_cls = None module_cls = None
if attn_type is not None: if attn_type:
if isinstance(attn_type, str): if isinstance(attn_type, str):
attn_type = attn_type.lower() attn_type = attn_type.lower()
# Lightweight attention modules (channel and/or coarse spatial). # Lightweight attention modules (channel and/or coarse spatial).

@ -14,11 +14,59 @@ class GroupNorm(nn.GroupNorm):
return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
class GroupNorm1(nn.GroupNorm):
""" Group Normalization with 1 group.
Input: tensor in shape [B, C, *]
"""
def __init__(self, num_channels, **kwargs):
super().__init__(1, num_channels, **kwargs)
class LayerNorm2d(nn.LayerNorm): class LayerNorm2d(nn.LayerNorm):
""" LayerNorm for channels of '2D' spatial BCHW tensors """ """ LayerNorm for channels of '2D' spatial NCHW tensors """
def __init__(self, num_channels, eps=1e-6): def __init__(self, num_channels, eps=1e-6, affine=True):
super().__init__(num_channels, eps=eps) super().__init__(num_channels, eps=eps, elementwise_affine=affine)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.layer_norm( return F.layer_norm(
x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
def _is_contiguous(tensor: torch.Tensor) -> bool:
# jit is oh so lovely :/
# if torch.jit.is_tracing():
# return True
if torch.jit.is_scripting():
return tensor.is_contiguous()
else:
return tensor.is_contiguous(memory_format=torch.contiguous_format)
@torch.jit.script
def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
x = (x - u) * torch.rsqrt(s + eps)
x = x * weight[:, None, None] + bias[:, None, None]
return x
class LayerNormExp2d(nn.LayerNorm):
""" LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).
Experimental implementation w/ manual norm for tensors non-contiguous tensors.
This improves throughput in some scenarios (tested on Ampere GPU), esp w/ channels_last
layout. However, benefits are not always clear and can perform worse on other GPUs.
"""
def __init__(self, num_channels, eps=1e-6):
super().__init__(num_channels, eps=eps)
def forward(self, x) -> torch.Tensor:
if _is_contiguous(x):
x = F.layer_norm(
x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
else:
x = _layer_norm_cf(x, self.weight, self.bias, self.eps)
return x

@ -1,7 +1,8 @@
""" MobileViT """ MobileViT
Paper: Paper:
`MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer` - https://arxiv.org/abs/2110.02178 V1: `MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer` - https://arxiv.org/abs/2110.02178
V2: `Separable Self-attention for Mobile Vision Transformers` - https://arxiv.org/abs/2206.02680
MobileVitBlock and checkpoints adapted from https://github.com/apple/ml-cvnets (original copyright below) MobileVitBlock and checkpoints adapted from https://github.com/apple/ml-cvnets (original copyright below)
License: https://github.com/apple/ml-cvnets/blob/main/LICENSE (Apple open source) License: https://github.com/apple/ml-cvnets/blob/main/LICENSE (Apple open source)
@ -13,7 +14,7 @@ Rest of code, ByobNet, and Transformer block hacked together by / Copyright 2022
# Copyright (C) 2020 Apple Inc. All Rights Reserved. # Copyright (C) 2020 Apple Inc. All Rights Reserved.
# #
import math import math
from typing import Union, Callable, Dict, Tuple, Optional from typing import Union, Callable, Dict, Tuple, Optional, Sequence
import torch import torch
from torch import nn from torch import nn
@ -21,7 +22,7 @@ import torch.nn.functional as F
from .byobnet import register_block, ByoBlockCfg, ByoModelCfg, ByobNet, LayerFn, num_groups from .byobnet import register_block, ByoBlockCfg, ByoModelCfg, ByobNet, LayerFn, num_groups
from .fx_features import register_notrace_module from .fx_features import register_notrace_module
from .layers import to_2tuple, make_divisible from .layers import to_2tuple, make_divisible, LayerNorm2d, GroupNorm1, ConvMlp, DropPath
from .vision_transformer import Block as TransformerBlock from .vision_transformer import Block as TransformerBlock
from .helpers import build_model_with_cfg from .helpers import build_model_with_cfg
from .registry import register_model from .registry import register_model
@ -48,6 +49,48 @@ default_cfgs = {
'mobilevit_s': _cfg( 'mobilevit_s': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevit_s-38a5a959.pth'), url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevit_s-38a5a959.pth'),
'semobilevit_s': _cfg(), 'semobilevit_s': _cfg(),
'mobilevitv2_050': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_050-49951ee2.pth',
crop_pct=0.888),
'mobilevitv2_075': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_075-b5556ef6.pth',
crop_pct=0.888),
'mobilevitv2_100': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_100-e464ef3b.pth',
crop_pct=0.888),
'mobilevitv2_125': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_125-0ae35027.pth',
crop_pct=0.888),
'mobilevitv2_150': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_150-737c5019.pth',
crop_pct=0.888),
'mobilevitv2_175': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_175-16462ee2.pth',
crop_pct=0.888),
'mobilevitv2_200': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_200-b3422f67.pth',
crop_pct=0.888),
'mobilevitv2_150_in22ft1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_150_in22ft1k-0b555d7b.pth',
crop_pct=0.888),
'mobilevitv2_175_in22ft1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_175_in22ft1k-4117fa1f.pth',
crop_pct=0.888),
'mobilevitv2_200_in22ft1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_200_in22ft1k-1d7c8927.pth',
crop_pct=0.888),
'mobilevitv2_150_384_in22ft1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_150_384_in22ft1k-9e142854.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
'mobilevitv2_175_384_in22ft1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_175_384_in22ft1k-059cbe56.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
'mobilevitv2_200_384_in22ft1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_200_384_in22ft1k-32c87503.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
} }
@ -72,6 +115,40 @@ def _mobilevit_block(d, c, s, transformer_dim, transformer_depth, patch_size=4,
) )
def _mobilevitv2_block(d, c, s, transformer_depth, patch_size=2, br=2.0, transformer_br=0.5):
# inverted residual + mobilevit blocks as per MobileViT network
return (
_inverted_residual_block(d=d, c=c, s=s, br=br),
ByoBlockCfg(
type='mobilevit2', d=1, c=c, s=1, br=transformer_br, gs=1,
block_kwargs=dict(
transformer_depth=transformer_depth,
patch_size=patch_size)
)
)
def _mobilevitv2_cfg(multiplier=1.0):
chs = (64, 128, 256, 384, 512)
if multiplier != 1.0:
chs = tuple([int(c * multiplier) for c in chs])
cfg = ByoModelCfg(
blocks=(
_inverted_residual_block(d=1, c=chs[0], s=1, br=2.0),
_inverted_residual_block(d=2, c=chs[1], s=2, br=2.0),
_mobilevitv2_block(d=1, c=chs[2], s=2, transformer_depth=2),
_mobilevitv2_block(d=1, c=chs[3], s=2, transformer_depth=4),
_mobilevitv2_block(d=1, c=chs[4], s=2, transformer_depth=3),
),
stem_chs=int(32 * multiplier),
stem_type='3x3',
stem_pool='',
downsample='',
act_layer='silu',
)
return cfg
model_cfgs = dict( model_cfgs = dict(
mobilevit_xxs=ByoModelCfg( mobilevit_xxs=ByoModelCfg(
blocks=( blocks=(
@ -137,11 +214,19 @@ model_cfgs = dict(
attn_kwargs=dict(rd_ratio=1/8), attn_kwargs=dict(rd_ratio=1/8),
num_features=640, num_features=640,
), ),
mobilevitv2_050=_mobilevitv2_cfg(.50),
mobilevitv2_075=_mobilevitv2_cfg(.75),
mobilevitv2_125=_mobilevitv2_cfg(1.25),
mobilevitv2_100=_mobilevitv2_cfg(1.0),
mobilevitv2_150=_mobilevitv2_cfg(1.5),
mobilevitv2_175=_mobilevitv2_cfg(1.75),
mobilevitv2_200=_mobilevitv2_cfg(2.0),
) )
@register_notrace_module @register_notrace_module
class MobileViTBlock(nn.Module): class MobileVitBlock(nn.Module):
""" MobileViT block """ MobileViT block
Paper: https://arxiv.org/abs/2110.02178?context=cs.LG Paper: https://arxiv.org/abs/2110.02178?context=cs.LG
""" """
@ -165,9 +250,9 @@ class MobileViTBlock(nn.Module):
drop_path_rate: float = 0., drop_path_rate: float = 0.,
layers: LayerFn = None, layers: LayerFn = None,
transformer_norm_layer: Callable = nn.LayerNorm, transformer_norm_layer: Callable = nn.LayerNorm,
downsample: str = '' **kwargs, # eat unused args
): ):
super(MobileViTBlock, self).__init__() super(MobileVitBlock, self).__init__()
layers = layers or LayerFn() layers = layers or LayerFn()
groups = num_groups(group_size, in_chs) groups = num_groups(group_size, in_chs)
@ -241,7 +326,270 @@ class MobileViTBlock(nn.Module):
return x return x
register_block('mobilevit', MobileViTBlock) class LinearSelfAttention(nn.Module):
"""
This layer applies a self-attention with linear complexity, as described in `https://arxiv.org/abs/2206.02680`
This layer can be used for self- as well as cross-attention.
Args:
embed_dim (int): :math:`C` from an expected input of size :math:`(N, C, H, W)`
attn_drop (float): Dropout value for context scores. Default: 0.0
bias (bool): Use bias in learnable layers. Default: True
Shape:
- Input: :math:`(N, C, P, N)` where :math:`N` is the batch size, :math:`C` is the input channels,
:math:`P` is the number of pixels in the patch, and :math:`N` is the number of patches
- Output: same as the input
.. note::
For MobileViTv2, we unfold the feature map [B, C, H, W] into [B, C, P, N] where P is the number of pixels
in a patch and N is the number of patches. Because channel is the first dimension in this unfolded tensor,
we use point-wise convolution (instead of a linear layer). This avoids a transpose operation (which may be
expensive on resource-constrained devices) that may be required to convert the unfolded tensor from
channel-first to channel-last format in case of a linear layer.
"""
def __init__(
self,
embed_dim: int,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
bias: bool = True,
) -> None:
super().__init__()
self.embed_dim = embed_dim
self.qkv_proj = nn.Conv2d(
in_channels=embed_dim,
out_channels=1 + (2 * embed_dim),
bias=bias,
kernel_size=1,
)
self.attn_drop = nn.Dropout(attn_drop)
self.out_proj = nn.Conv2d(
in_channels=embed_dim,
out_channels=embed_dim,
bias=bias,
kernel_size=1,
)
self.out_drop = nn.Dropout(proj_drop)
def _forward_self_attn(self, x: torch.Tensor) -> torch.Tensor:
# [B, C, P, N] --> [B, h + 2d, P, N]
qkv = self.qkv_proj(x)
# Project x into query, key and value
# Query --> [B, 1, P, N]
# value, key --> [B, d, P, N]
query, key, value = qkv.split([1, self.embed_dim, self.embed_dim], dim=1)
# apply softmax along N dimension
context_scores = F.softmax(query, dim=-1)
context_scores = self.attn_drop(context_scores)
# Compute context vector
# [B, d, P, N] x [B, 1, P, N] -> [B, d, P, N] --> [B, d, P, 1]
context_vector = (key * context_scores).sum(dim=-1, keepdim=True)
# combine context vector with values
# [B, d, P, N] * [B, d, P, 1] --> [B, d, P, N]
out = F.relu(value) * context_vector.expand_as(value)
out = self.out_proj(out)
out = self.out_drop(out)
return out
@torch.jit.ignore()
def _forward_cross_attn(self, x: torch.Tensor, x_prev: Optional[torch.Tensor] = None) -> torch.Tensor:
# x --> [B, C, P, N]
# x_prev = [B, C, P, M]
batch_size, in_dim, kv_patch_area, kv_num_patches = x.shape
q_patch_area, q_num_patches = x.shape[-2:]
assert (
kv_patch_area == q_patch_area
), "The number of pixels in a patch for query and key_value should be the same"
# compute query, key, and value
# [B, C, P, M] --> [B, 1 + d, P, M]
qk = F.conv2d(
x_prev,
weight=self.qkv_proj.weight[:self.embed_dim + 1],
bias=self.qkv_proj.bias[:self.embed_dim + 1],
)
# [B, 1 + d, P, M] --> [B, 1, P, M], [B, d, P, M]
query, key = qk.split([1, self.embed_dim], dim=1)
# [B, C, P, N] --> [B, d, P, N]
value = F.conv2d(
x,
weight=self.qkv_proj.weight[self.embed_dim + 1],
bias=self.qkv_proj.bias[self.embed_dim + 1] if self.qkv_proj.bias is not None else None,
)
# apply softmax along M dimension
context_scores = F.softmax(query, dim=-1)
context_scores = self.attn_drop(context_scores)
# compute context vector
# [B, d, P, M] * [B, 1, P, M] -> [B, d, P, M] --> [B, d, P, 1]
context_vector = (key * context_scores).sum(dim=-1, keepdim=True)
# combine context vector with values
# [B, d, P, N] * [B, d, P, 1] --> [B, d, P, N]
out = F.relu(value) * context_vector.expand_as(value)
out = self.out_proj(out)
out = self.out_drop(out)
return out
def forward(self, x: torch.Tensor, x_prev: Optional[torch.Tensor] = None) -> torch.Tensor:
if x_prev is None:
return self._forward_self_attn(x)
else:
return self._forward_cross_attn(x, x_prev=x_prev)
class LinearTransformerBlock(nn.Module):
"""
This class defines the pre-norm transformer encoder with linear self-attention in `MobileViTv2 paper <>`_
Args:
embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, P, N)`
mlp_ratio (float): Inner dimension ratio of the FFN relative to embed_dim
drop (float): Dropout rate. Default: 0.0
attn_drop (float): Dropout rate for attention in multi-head attention. Default: 0.0
drop_path (float): Stochastic depth rate Default: 0.0
norm_layer (Callable): Normalization layer. Default: layer_norm_2d
Shape:
- Input: :math:`(B, C_{in}, P, N)` where :math:`B` is batch size, :math:`C_{in}` is input embedding dim,
:math:`P` is number of pixels in a patch, and :math:`N` is number of patches,
- Output: same shape as the input
"""
def __init__(
self,
embed_dim: int,
mlp_ratio: float = 2.0,
drop: float = 0.0,
attn_drop: float = 0.0,
drop_path: float = 0.0,
act_layer=None,
norm_layer=None,
) -> None:
super().__init__()
act_layer = act_layer or nn.SiLU
norm_layer = norm_layer or GroupNorm1
self.norm1 = norm_layer(embed_dim)
self.attn = LinearSelfAttention(embed_dim=embed_dim, attn_drop=attn_drop, proj_drop=drop)
self.drop_path1 = DropPath(drop_path)
self.norm2 = norm_layer(embed_dim)
self.mlp = ConvMlp(
in_features=embed_dim,
hidden_features=int(embed_dim * mlp_ratio),
act_layer=act_layer,
drop=drop)
self.drop_path2 = DropPath(drop_path)
def forward(self, x: torch.Tensor, x_prev: Optional[torch.Tensor] = None) -> torch.Tensor:
if x_prev is None:
# self-attention
x = x + self.drop_path1(self.attn(self.norm1(x)))
else:
# cross-attention
res = x
x = self.norm1(x) # norm
x = self.attn(x, x_prev) # attn
x = self.drop_path1(x) + res # residual
# Feed forward network
x = x + self.drop_path2(self.mlp(self.norm2(x)))
return x
@register_notrace_module
class MobileVitV2Block(nn.Module):
"""
This class defines the `MobileViTv2 block <>`_
"""
def __init__(
self,
in_chs: int,
out_chs: Optional[int] = None,
kernel_size: int = 3,
bottle_ratio: float = 1.0,
group_size: Optional[int] = 1,
dilation: Tuple[int, int] = (1, 1),
mlp_ratio: float = 2.0,
transformer_dim: Optional[int] = None,
transformer_depth: int = 2,
patch_size: int = 8,
attn_drop: float = 0.,
drop: int = 0.,
drop_path_rate: float = 0.,
layers: LayerFn = None,
transformer_norm_layer: Callable = GroupNorm1,
**kwargs, # eat unused args
):
super(MobileVitV2Block, self).__init__()
layers = layers or LayerFn()
groups = num_groups(group_size, in_chs)
out_chs = out_chs or in_chs
transformer_dim = transformer_dim or make_divisible(bottle_ratio * in_chs)
self.conv_kxk = layers.conv_norm_act(
in_chs, in_chs, kernel_size=kernel_size,
stride=1, groups=groups, dilation=dilation[0])
self.conv_1x1 = nn.Conv2d(in_chs, transformer_dim, kernel_size=1, bias=False)
self.transformer = nn.Sequential(*[
LinearTransformerBlock(
transformer_dim,
mlp_ratio=mlp_ratio,
attn_drop=attn_drop,
drop=drop,
drop_path=drop_path_rate,
act_layer=layers.act,
norm_layer=transformer_norm_layer
)
for _ in range(transformer_depth)
])
self.norm = transformer_norm_layer(transformer_dim)
self.conv_proj = layers.conv_norm_act(transformer_dim, out_chs, kernel_size=1, stride=1, apply_act=False)
self.patch_size = to_2tuple(patch_size)
self.patch_area = self.patch_size[0] * self.patch_size[1]
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, C, H, W = x.shape
patch_h, patch_w = self.patch_size
new_h, new_w = math.ceil(H / patch_h) * patch_h, math.ceil(W / patch_w) * patch_w
num_patch_h, num_patch_w = new_h // patch_h, new_w // patch_w # n_h, n_w
num_patches = num_patch_h * num_patch_w # N
if new_h != H or new_w != W:
x = F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=True)
# Local representation
x = self.conv_kxk(x)
x = self.conv_1x1(x)
# Unfold (feature map -> patches), [B, C, H, W] -> [B, C, P, N]
C = x.shape[1]
x = x.reshape(B, C, num_patch_h, patch_h, num_patch_w, patch_w).permute(0, 1, 3, 5, 2, 4)
x = x.reshape(B, C, -1, num_patches)
# Global representations
x = self.transformer(x)
x = self.norm(x)
# Fold (patches -> feature map), [B, C, P, N] --> [B, C, H, W]
x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3)
x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)
x = self.conv_proj(x)
return x
register_block('mobilevit', MobileVitBlock)
register_block('mobilevit2', MobileVitV2Block)
def _create_mobilevit(variant, cfg_variant=None, pretrained=False, **kwargs): def _create_mobilevit(variant, cfg_variant=None, pretrained=False, **kwargs):
@ -252,6 +600,14 @@ def _create_mobilevit(variant, cfg_variant=None, pretrained=False, **kwargs):
**kwargs) **kwargs)
def _create_mobilevit2(variant, cfg_variant=None, pretrained=False, **kwargs):
return build_model_with_cfg(
ByobNet, variant, pretrained,
model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
feature_cfg=dict(flatten_sequential=True),
**kwargs)
@register_model @register_model
def mobilevit_xxs(pretrained=False, **kwargs): def mobilevit_xxs(pretrained=False, **kwargs):
return _create_mobilevit('mobilevit_xxs', pretrained=pretrained, **kwargs) return _create_mobilevit('mobilevit_xxs', pretrained=pretrained, **kwargs)
@ -270,3 +626,74 @@ def mobilevit_s(pretrained=False, **kwargs):
@register_model @register_model
def semobilevit_s(pretrained=False, **kwargs): def semobilevit_s(pretrained=False, **kwargs):
return _create_mobilevit('semobilevit_s', pretrained=pretrained, **kwargs) return _create_mobilevit('semobilevit_s', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_050(pretrained=False, **kwargs):
return _create_mobilevit('mobilevitv2_050', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_075(pretrained=False, **kwargs):
return _create_mobilevit('mobilevitv2_075', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_100(pretrained=False, **kwargs):
return _create_mobilevit('mobilevitv2_100', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_125(pretrained=False, **kwargs):
return _create_mobilevit('mobilevitv2_125', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_150(pretrained=False, **kwargs):
return _create_mobilevit('mobilevitv2_150', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_175(pretrained=False, **kwargs):
return _create_mobilevit('mobilevitv2_175', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_200(pretrained=False, **kwargs):
return _create_mobilevit('mobilevitv2_200', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_150_in22ft1k(pretrained=False, **kwargs):
return _create_mobilevit(
'mobilevitv2_150_in22ft1k', cfg_variant='mobilevitv2_150', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_175_in22ft1k(pretrained=False, **kwargs):
return _create_mobilevit(
'mobilevitv2_175_in22ft1k', cfg_variant='mobilevitv2_175', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_200_in22ft1k(pretrained=False, **kwargs):
return _create_mobilevit(
'mobilevitv2_200_in22ft1k', cfg_variant='mobilevitv2_200', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_150_384_in22ft1k(pretrained=False, **kwargs):
return _create_mobilevit(
'mobilevitv2_150_384_in22ft1k', cfg_variant='mobilevitv2_150', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_175_384_in22ft1k(pretrained=False, **kwargs):
return _create_mobilevit(
'mobilevitv2_175_384_in22ft1k', cfg_variant='mobilevitv2_175', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_200_384_in22ft1k(pretrained=False, **kwargs):
return _create_mobilevit(
'mobilevitv2_200_384_in22ft1k', cfg_variant='mobilevitv2_200', pretrained=pretrained, **kwargs)

@ -26,7 +26,7 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, checkpoint_seq from .helpers import build_model_with_cfg, checkpoint_seq
from .layers import DropPath, trunc_normal_, to_2tuple, ConvMlp from .layers import DropPath, trunc_normal_, to_2tuple, ConvMlp, GroupNorm1
from .registry import register_model from .registry import register_model
@ -80,15 +80,6 @@ class PatchEmbed(nn.Module):
return x return x
class GroupNorm1(nn.GroupNorm):
""" Group Normalization with 1 group.
Input: tensor in shape [B, C, H, W]
"""
def __init__(self, num_channels, **kwargs):
super().__init__(1, num_channels, **kwargs)
class Pooling(nn.Module): class Pooling(nn.Module):
def __init__(self, pool_size=3): def __init__(self, pool_size=3):
super().__init__() super().__init__()

Loading…
Cancel
Save