diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index b1f452ff..b9eeec0f 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -25,7 +25,7 @@ from .linear import Linear from .mixed_conv2d import MixedConv2d from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp from .non_local_attn import NonLocalAttn, BatNonLocalAttn -from .norm import GroupNorm, LayerNorm2d +from .norm import GroupNorm, GroupNorm1, LayerNorm2d from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm from .padding import get_padding, get_same_padding, pad_same from .patch_embed import PatchEmbed diff --git a/timm/models/layers/create_attn.py b/timm/models/layers/create_attn.py index 028c0f75..cc7e91ea 100644 --- a/timm/models/layers/create_attn.py +++ b/timm/models/layers/create_attn.py @@ -22,7 +22,7 @@ def get_attn(attn_type): if isinstance(attn_type, torch.nn.Module): return attn_type module_cls = None - if attn_type is not None: + if attn_type: if isinstance(attn_type, str): attn_type = attn_type.lower() # Lightweight attention modules (channel and/or coarse spatial). diff --git a/timm/models/layers/norm.py b/timm/models/layers/norm.py index 345f67bc..1677dbfa 100644 --- a/timm/models/layers/norm.py +++ b/timm/models/layers/norm.py @@ -14,11 +14,59 @@ class GroupNorm(nn.GroupNorm): return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) +class GroupNorm1(nn.GroupNorm): + """ Group Normalization with 1 group. + Input: tensor in shape [B, C, *] + """ + + def __init__(self, num_channels, **kwargs): + super().__init__(1, num_channels, **kwargs) + + class LayerNorm2d(nn.LayerNorm): - """ LayerNorm for channels of '2D' spatial BCHW tensors """ - def __init__(self, num_channels, eps=1e-6): - super().__init__(num_channels, eps=eps) + """ LayerNorm for channels of '2D' spatial NCHW tensors """ + def __init__(self, num_channels, eps=1e-6, affine=True): + super().__init__(num_channels, eps=eps, elementwise_affine=affine) def forward(self, x: torch.Tensor) -> torch.Tensor: return F.layer_norm( x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) + + +def _is_contiguous(tensor: torch.Tensor) -> bool: + # jit is oh so lovely :/ + # if torch.jit.is_tracing(): + # return True + if torch.jit.is_scripting(): + return tensor.is_contiguous() + else: + return tensor.is_contiguous(memory_format=torch.contiguous_format) + + +@torch.jit.script +def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float): + s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True) + x = (x - u) * torch.rsqrt(s + eps) + x = x * weight[:, None, None] + bias[:, None, None] + return x + + +class LayerNormExp2d(nn.LayerNorm): + """ LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W). + + Experimental implementation w/ manual norm for tensors non-contiguous tensors. + + This improves throughput in some scenarios (tested on Ampere GPU), esp w/ channels_last + layout. However, benefits are not always clear and can perform worse on other GPUs. + """ + + def __init__(self, num_channels, eps=1e-6): + super().__init__(num_channels, eps=eps) + + def forward(self, x) -> torch.Tensor: + if _is_contiguous(x): + x = F.layer_norm( + x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) + else: + x = _layer_norm_cf(x, self.weight, self.bias, self.eps) + return x diff --git a/timm/models/mobilevit.py b/timm/models/mobilevit.py index 1c55bd1c..2a3ab924 100644 --- a/timm/models/mobilevit.py +++ b/timm/models/mobilevit.py @@ -1,7 +1,8 @@ """ MobileViT Paper: -`MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer` - https://arxiv.org/abs/2110.02178 +V1: `MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer` - https://arxiv.org/abs/2110.02178 +V2: `Separable Self-attention for Mobile Vision Transformers` - https://arxiv.org/abs/2206.02680 MobileVitBlock and checkpoints adapted from https://github.com/apple/ml-cvnets (original copyright below) License: https://github.com/apple/ml-cvnets/blob/main/LICENSE (Apple open source) @@ -13,7 +14,7 @@ Rest of code, ByobNet, and Transformer block hacked together by / Copyright 2022 # Copyright (C) 2020 Apple Inc. All Rights Reserved. # import math -from typing import Union, Callable, Dict, Tuple, Optional +from typing import Union, Callable, Dict, Tuple, Optional, Sequence import torch from torch import nn @@ -21,7 +22,7 @@ import torch.nn.functional as F from .byobnet import register_block, ByoBlockCfg, ByoModelCfg, ByobNet, LayerFn, num_groups from .fx_features import register_notrace_module -from .layers import to_2tuple, make_divisible +from .layers import to_2tuple, make_divisible, LayerNorm2d, GroupNorm1, ConvMlp, DropPath from .vision_transformer import Block as TransformerBlock from .helpers import build_model_with_cfg from .registry import register_model @@ -48,6 +49,48 @@ default_cfgs = { 'mobilevit_s': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevit_s-38a5a959.pth'), 'semobilevit_s': _cfg(), + + 'mobilevitv2_050': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_050-49951ee2.pth', + crop_pct=0.888), + 'mobilevitv2_075': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_075-b5556ef6.pth', + crop_pct=0.888), + 'mobilevitv2_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_100-e464ef3b.pth', + crop_pct=0.888), + 'mobilevitv2_125': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_125-0ae35027.pth', + crop_pct=0.888), + 'mobilevitv2_150': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_150-737c5019.pth', + crop_pct=0.888), + 'mobilevitv2_175': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_175-16462ee2.pth', + crop_pct=0.888), + 'mobilevitv2_200': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_200-b3422f67.pth', + crop_pct=0.888), + + 'mobilevitv2_150_in22ft1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_150_in22ft1k-0b555d7b.pth', + crop_pct=0.888), + 'mobilevitv2_175_in22ft1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_175_in22ft1k-4117fa1f.pth', + crop_pct=0.888), + 'mobilevitv2_200_in22ft1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_200_in22ft1k-1d7c8927.pth', + crop_pct=0.888), + + 'mobilevitv2_150_384_in22ft1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_150_384_in22ft1k-9e142854.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + 'mobilevitv2_175_384_in22ft1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_175_384_in22ft1k-059cbe56.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + 'mobilevitv2_200_384_in22ft1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_200_384_in22ft1k-32c87503.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), } @@ -72,6 +115,40 @@ def _mobilevit_block(d, c, s, transformer_dim, transformer_depth, patch_size=4, ) +def _mobilevitv2_block(d, c, s, transformer_depth, patch_size=2, br=2.0, transformer_br=0.5): + # inverted residual + mobilevit blocks as per MobileViT network + return ( + _inverted_residual_block(d=d, c=c, s=s, br=br), + ByoBlockCfg( + type='mobilevit2', d=1, c=c, s=1, br=transformer_br, gs=1, + block_kwargs=dict( + transformer_depth=transformer_depth, + patch_size=patch_size) + ) + ) + + +def _mobilevitv2_cfg(multiplier=1.0): + chs = (64, 128, 256, 384, 512) + if multiplier != 1.0: + chs = tuple([int(c * multiplier) for c in chs]) + cfg = ByoModelCfg( + blocks=( + _inverted_residual_block(d=1, c=chs[0], s=1, br=2.0), + _inverted_residual_block(d=2, c=chs[1], s=2, br=2.0), + _mobilevitv2_block(d=1, c=chs[2], s=2, transformer_depth=2), + _mobilevitv2_block(d=1, c=chs[3], s=2, transformer_depth=4), + _mobilevitv2_block(d=1, c=chs[4], s=2, transformer_depth=3), + ), + stem_chs=int(32 * multiplier), + stem_type='3x3', + stem_pool='', + downsample='', + act_layer='silu', + ) + return cfg + + model_cfgs = dict( mobilevit_xxs=ByoModelCfg( blocks=( @@ -137,11 +214,19 @@ model_cfgs = dict( attn_kwargs=dict(rd_ratio=1/8), num_features=640, ), + + mobilevitv2_050=_mobilevitv2_cfg(.50), + mobilevitv2_075=_mobilevitv2_cfg(.75), + mobilevitv2_125=_mobilevitv2_cfg(1.25), + mobilevitv2_100=_mobilevitv2_cfg(1.0), + mobilevitv2_150=_mobilevitv2_cfg(1.5), + mobilevitv2_175=_mobilevitv2_cfg(1.75), + mobilevitv2_200=_mobilevitv2_cfg(2.0), ) @register_notrace_module -class MobileViTBlock(nn.Module): +class MobileVitBlock(nn.Module): """ MobileViT block Paper: https://arxiv.org/abs/2110.02178?context=cs.LG """ @@ -165,9 +250,9 @@ class MobileViTBlock(nn.Module): drop_path_rate: float = 0., layers: LayerFn = None, transformer_norm_layer: Callable = nn.LayerNorm, - downsample: str = '' + **kwargs, # eat unused args ): - super(MobileViTBlock, self).__init__() + super(MobileVitBlock, self).__init__() layers = layers or LayerFn() groups = num_groups(group_size, in_chs) @@ -241,7 +326,270 @@ class MobileViTBlock(nn.Module): return x -register_block('mobilevit', MobileViTBlock) +class LinearSelfAttention(nn.Module): + """ + This layer applies a self-attention with linear complexity, as described in `https://arxiv.org/abs/2206.02680` + This layer can be used for self- as well as cross-attention. + Args: + embed_dim (int): :math:`C` from an expected input of size :math:`(N, C, H, W)` + attn_drop (float): Dropout value for context scores. Default: 0.0 + bias (bool): Use bias in learnable layers. Default: True + Shape: + - Input: :math:`(N, C, P, N)` where :math:`N` is the batch size, :math:`C` is the input channels, + :math:`P` is the number of pixels in the patch, and :math:`N` is the number of patches + - Output: same as the input + .. note:: + For MobileViTv2, we unfold the feature map [B, C, H, W] into [B, C, P, N] where P is the number of pixels + in a patch and N is the number of patches. Because channel is the first dimension in this unfolded tensor, + we use point-wise convolution (instead of a linear layer). This avoids a transpose operation (which may be + expensive on resource-constrained devices) that may be required to convert the unfolded tensor from + channel-first to channel-last format in case of a linear layer. + """ + + def __init__( + self, + embed_dim: int, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + self.embed_dim = embed_dim + + self.qkv_proj = nn.Conv2d( + in_channels=embed_dim, + out_channels=1 + (2 * embed_dim), + bias=bias, + kernel_size=1, + ) + self.attn_drop = nn.Dropout(attn_drop) + self.out_proj = nn.Conv2d( + in_channels=embed_dim, + out_channels=embed_dim, + bias=bias, + kernel_size=1, + ) + self.out_drop = nn.Dropout(proj_drop) + + def _forward_self_attn(self, x: torch.Tensor) -> torch.Tensor: + # [B, C, P, N] --> [B, h + 2d, P, N] + qkv = self.qkv_proj(x) + + # Project x into query, key and value + # Query --> [B, 1, P, N] + # value, key --> [B, d, P, N] + query, key, value = qkv.split([1, self.embed_dim, self.embed_dim], dim=1) + + # apply softmax along N dimension + context_scores = F.softmax(query, dim=-1) + context_scores = self.attn_drop(context_scores) + + # Compute context vector + # [B, d, P, N] x [B, 1, P, N] -> [B, d, P, N] --> [B, d, P, 1] + context_vector = (key * context_scores).sum(dim=-1, keepdim=True) + + # combine context vector with values + # [B, d, P, N] * [B, d, P, 1] --> [B, d, P, N] + out = F.relu(value) * context_vector.expand_as(value) + out = self.out_proj(out) + out = self.out_drop(out) + return out + + @torch.jit.ignore() + def _forward_cross_attn(self, x: torch.Tensor, x_prev: Optional[torch.Tensor] = None) -> torch.Tensor: + # x --> [B, C, P, N] + # x_prev = [B, C, P, M] + batch_size, in_dim, kv_patch_area, kv_num_patches = x.shape + q_patch_area, q_num_patches = x.shape[-2:] + + assert ( + kv_patch_area == q_patch_area + ), "The number of pixels in a patch for query and key_value should be the same" + + # compute query, key, and value + # [B, C, P, M] --> [B, 1 + d, P, M] + qk = F.conv2d( + x_prev, + weight=self.qkv_proj.weight[:self.embed_dim + 1], + bias=self.qkv_proj.bias[:self.embed_dim + 1], + ) + + # [B, 1 + d, P, M] --> [B, 1, P, M], [B, d, P, M] + query, key = qk.split([1, self.embed_dim], dim=1) + # [B, C, P, N] --> [B, d, P, N] + value = F.conv2d( + x, + weight=self.qkv_proj.weight[self.embed_dim + 1], + bias=self.qkv_proj.bias[self.embed_dim + 1] if self.qkv_proj.bias is not None else None, + ) + + # apply softmax along M dimension + context_scores = F.softmax(query, dim=-1) + context_scores = self.attn_drop(context_scores) + + # compute context vector + # [B, d, P, M] * [B, 1, P, M] -> [B, d, P, M] --> [B, d, P, 1] + context_vector = (key * context_scores).sum(dim=-1, keepdim=True) + + # combine context vector with values + # [B, d, P, N] * [B, d, P, 1] --> [B, d, P, N] + out = F.relu(value) * context_vector.expand_as(value) + out = self.out_proj(out) + out = self.out_drop(out) + return out + + def forward(self, x: torch.Tensor, x_prev: Optional[torch.Tensor] = None) -> torch.Tensor: + if x_prev is None: + return self._forward_self_attn(x) + else: + return self._forward_cross_attn(x, x_prev=x_prev) + + +class LinearTransformerBlock(nn.Module): + """ + This class defines the pre-norm transformer encoder with linear self-attention in `MobileViTv2 paper <>`_ + Args: + embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, P, N)` + mlp_ratio (float): Inner dimension ratio of the FFN relative to embed_dim + drop (float): Dropout rate. Default: 0.0 + attn_drop (float): Dropout rate for attention in multi-head attention. Default: 0.0 + drop_path (float): Stochastic depth rate Default: 0.0 + norm_layer (Callable): Normalization layer. Default: layer_norm_2d + Shape: + - Input: :math:`(B, C_{in}, P, N)` where :math:`B` is batch size, :math:`C_{in}` is input embedding dim, + :math:`P` is number of pixels in a patch, and :math:`N` is number of patches, + - Output: same shape as the input + """ + + def __init__( + self, + embed_dim: int, + mlp_ratio: float = 2.0, + drop: float = 0.0, + attn_drop: float = 0.0, + drop_path: float = 0.0, + act_layer=None, + norm_layer=None, + ) -> None: + super().__init__() + act_layer = act_layer or nn.SiLU + norm_layer = norm_layer or GroupNorm1 + + self.norm1 = norm_layer(embed_dim) + self.attn = LinearSelfAttention(embed_dim=embed_dim, attn_drop=attn_drop, proj_drop=drop) + self.drop_path1 = DropPath(drop_path) + + self.norm2 = norm_layer(embed_dim) + self.mlp = ConvMlp( + in_features=embed_dim, + hidden_features=int(embed_dim * mlp_ratio), + act_layer=act_layer, + drop=drop) + self.drop_path2 = DropPath(drop_path) + + def forward(self, x: torch.Tensor, x_prev: Optional[torch.Tensor] = None) -> torch.Tensor: + if x_prev is None: + # self-attention + x = x + self.drop_path1(self.attn(self.norm1(x))) + else: + # cross-attention + res = x + x = self.norm1(x) # norm + x = self.attn(x, x_prev) # attn + x = self.drop_path1(x) + res # residual + + # Feed forward network + x = x + self.drop_path2(self.mlp(self.norm2(x))) + return x + + +@register_notrace_module +class MobileVitV2Block(nn.Module): + """ + This class defines the `MobileViTv2 block <>`_ + """ + + def __init__( + self, + in_chs: int, + out_chs: Optional[int] = None, + kernel_size: int = 3, + bottle_ratio: float = 1.0, + group_size: Optional[int] = 1, + dilation: Tuple[int, int] = (1, 1), + mlp_ratio: float = 2.0, + transformer_dim: Optional[int] = None, + transformer_depth: int = 2, + patch_size: int = 8, + attn_drop: float = 0., + drop: int = 0., + drop_path_rate: float = 0., + layers: LayerFn = None, + transformer_norm_layer: Callable = GroupNorm1, + **kwargs, # eat unused args + ): + super(MobileVitV2Block, self).__init__() + layers = layers or LayerFn() + groups = num_groups(group_size, in_chs) + out_chs = out_chs or in_chs + transformer_dim = transformer_dim or make_divisible(bottle_ratio * in_chs) + + self.conv_kxk = layers.conv_norm_act( + in_chs, in_chs, kernel_size=kernel_size, + stride=1, groups=groups, dilation=dilation[0]) + self.conv_1x1 = nn.Conv2d(in_chs, transformer_dim, kernel_size=1, bias=False) + + self.transformer = nn.Sequential(*[ + LinearTransformerBlock( + transformer_dim, + mlp_ratio=mlp_ratio, + attn_drop=attn_drop, + drop=drop, + drop_path=drop_path_rate, + act_layer=layers.act, + norm_layer=transformer_norm_layer + ) + for _ in range(transformer_depth) + ]) + self.norm = transformer_norm_layer(transformer_dim) + + self.conv_proj = layers.conv_norm_act(transformer_dim, out_chs, kernel_size=1, stride=1, apply_act=False) + + self.patch_size = to_2tuple(patch_size) + self.patch_area = self.patch_size[0] * self.patch_size[1] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, C, H, W = x.shape + patch_h, patch_w = self.patch_size + new_h, new_w = math.ceil(H / patch_h) * patch_h, math.ceil(W / patch_w) * patch_w + num_patch_h, num_patch_w = new_h // patch_h, new_w // patch_w # n_h, n_w + num_patches = num_patch_h * num_patch_w # N + if new_h != H or new_w != W: + x = F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=True) + + # Local representation + x = self.conv_kxk(x) + x = self.conv_1x1(x) + + # Unfold (feature map -> patches), [B, C, H, W] -> [B, C, P, N] + C = x.shape[1] + x = x.reshape(B, C, num_patch_h, patch_h, num_patch_w, patch_w).permute(0, 1, 3, 5, 2, 4) + x = x.reshape(B, C, -1, num_patches) + + # Global representations + x = self.transformer(x) + x = self.norm(x) + + # Fold (patches -> feature map), [B, C, P, N] --> [B, C, H, W] + x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3) + x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w) + + x = self.conv_proj(x) + return x + + +register_block('mobilevit', MobileVitBlock) +register_block('mobilevit2', MobileVitV2Block) def _create_mobilevit(variant, cfg_variant=None, pretrained=False, **kwargs): @@ -252,6 +600,14 @@ def _create_mobilevit(variant, cfg_variant=None, pretrained=False, **kwargs): **kwargs) +def _create_mobilevit2(variant, cfg_variant=None, pretrained=False, **kwargs): + return build_model_with_cfg( + ByobNet, variant, pretrained, + model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant], + feature_cfg=dict(flatten_sequential=True), + **kwargs) + + @register_model def mobilevit_xxs(pretrained=False, **kwargs): return _create_mobilevit('mobilevit_xxs', pretrained=pretrained, **kwargs) @@ -269,4 +625,75 @@ def mobilevit_s(pretrained=False, **kwargs): @register_model def semobilevit_s(pretrained=False, **kwargs): - return _create_mobilevit('semobilevit_s', pretrained=pretrained, **kwargs) \ No newline at end of file + return _create_mobilevit('semobilevit_s', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_050(pretrained=False, **kwargs): + return _create_mobilevit('mobilevitv2_050', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_075(pretrained=False, **kwargs): + return _create_mobilevit('mobilevitv2_075', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_100(pretrained=False, **kwargs): + return _create_mobilevit('mobilevitv2_100', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_125(pretrained=False, **kwargs): + return _create_mobilevit('mobilevitv2_125', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_150(pretrained=False, **kwargs): + return _create_mobilevit('mobilevitv2_150', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_175(pretrained=False, **kwargs): + return _create_mobilevit('mobilevitv2_175', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_200(pretrained=False, **kwargs): + return _create_mobilevit('mobilevitv2_200', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_150_in22ft1k(pretrained=False, **kwargs): + return _create_mobilevit( + 'mobilevitv2_150_in22ft1k', cfg_variant='mobilevitv2_150', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_175_in22ft1k(pretrained=False, **kwargs): + return _create_mobilevit( + 'mobilevitv2_175_in22ft1k', cfg_variant='mobilevitv2_175', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_200_in22ft1k(pretrained=False, **kwargs): + return _create_mobilevit( + 'mobilevitv2_200_in22ft1k', cfg_variant='mobilevitv2_200', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_150_384_in22ft1k(pretrained=False, **kwargs): + return _create_mobilevit( + 'mobilevitv2_150_384_in22ft1k', cfg_variant='mobilevitv2_150', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_175_384_in22ft1k(pretrained=False, **kwargs): + return _create_mobilevit( + 'mobilevitv2_175_384_in22ft1k', cfg_variant='mobilevitv2_175', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_200_384_in22ft1k(pretrained=False, **kwargs): + return _create_mobilevit( + 'mobilevitv2_200_384_in22ft1k', cfg_variant='mobilevitv2_200', pretrained=pretrained, **kwargs) \ No newline at end of file diff --git a/timm/models/poolformer.py b/timm/models/poolformer.py index 17d657b0..a95195b4 100644 --- a/timm/models/poolformer.py +++ b/timm/models/poolformer.py @@ -26,7 +26,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import DropPath, trunc_normal_, to_2tuple, ConvMlp +from .layers import DropPath, trunc_normal_, to_2tuple, ConvMlp, GroupNorm1 from .registry import register_model @@ -80,15 +80,6 @@ class PatchEmbed(nn.Module): return x -class GroupNorm1(nn.GroupNorm): - """ Group Normalization with 1 group. - Input: tensor in shape [B, C, H, W] - """ - - def __init__(self, num_channels, **kwargs): - super().__init__(1, num_channels, **kwargs) - - class Pooling(nn.Module): def __init__(self, pool_size=3): super().__init__()