Merge branch 'master' into cait

pull/609/head
Ross Wightman 3 years ago
commit 5fcddb96a8

@@ -48,7 +48,7 @@ parser = argparse.ArgumentParser(description='PyTorch Benchmark')
 parser.add_argument('--model-list', metavar='NAME', default='',
                     help='txt file based list of model names to benchmark')
 parser.add_argument('--bench', default='both', type=str,
-                    help="Benchmark mode. One of 'inference', 'train', 'both'. Defaults to 'inference'")
+                    help="Benchmark mode. One of 'inference', 'train', 'both'. Defaults to 'both'")
 parser.add_argument('--detail', action='store_true', default=False,
                     help='Provide train fwd/bwd/opt breakdown detail if True. Defaults to False')
 parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',

@@ -15,7 +15,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
     torch._C._jit_set_profiling_mode(False)
 
 # transformer models don't support many of the spatial / feature based model functionalities
-NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*']
+NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', 'mixer_*']
 NUM_NON_STD = len(NON_STD_FILTERS)
 
 # exclude models that cause specific test failures
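Note: these filter entries are shell-style wildcards; the test suite (not shown in this hunk) matches candidate model names against them, presumably with fnmatch-style globbing. A minimal sketch of such a check, assuming fnmatch semantics:

    from fnmatch import fnmatch

    NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', 'mixer_*']

    def is_non_std(model_name: str) -> bool:
        # a model is "non-standard" if any wildcard pattern matches its name
        return any(fnmatch(model_name, pat) for pat in NON_STD_FILTERS)

    assert is_non_std('mixer_b16_224')   # newly covered by 'mixer_*'
    assert not is_non_std('resnet50')    # conv nets keep the full spatial/feature tests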

@@ -15,6 +15,7 @@ from .hrnet import *
 from .inception_resnet_v2 import *
 from .inception_v3 import *
 from .inception_v4 import *
+from .mlp_mixer import *
 from .mobilenetv3 import *
 from .nasnet import *
 from .nfnet import *

@@ -294,6 +294,8 @@ class SelfAttnBlock(nn.Module):
     def init_weights(self, zero_init_last_bn=False):
         if zero_init_last_bn:
             nn.init.zeros_(self.conv3_1x1.bn.weight)
+        if hasattr(self.self_attn, 'reset_parameters'):
+            self.self_attn.reset_parameters()
 
     def forward(self, x):
         shortcut = self.shortcut(x)
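Note: the added init hook is duck-typed: any self-attention layer defining reset_parameters() is re-initialized, others are left untouched. A minimal sketch of the dispatch with a stand-in attention module (ToyAttn is illustrative, not from this commit):

    import torch.nn as nn
    from timm.models.layers import trunc_normal_

    class ToyAttn(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.qkv = nn.Conv2d(dim, dim * 3, 1, bias=False)

        def reset_parameters(self):
            # fan-in scaled truncated normal, matching the attn layers changed below
            trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)

    attn = ToyAttn(64)
    if hasattr(attn, 'reset_parameters'):
        attn.reset_parameters()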

@@ -62,9 +62,9 @@ class DlaBasic(nn.Module):
         self.bn2 = nn.BatchNorm2d(planes)
         self.stride = stride
 
-    def forward(self, x, residual=None):
-        if residual is None:
-            residual = x
+    def forward(self, x, shortcut=None):
+        if shortcut is None:
+            shortcut = x
 
         out = self.conv1(x)
         out = self.bn1(out)
@@ -73,7 +73,7 @@ class DlaBasic(nn.Module):
         out = self.conv2(out)
         out = self.bn2(out)
 
-        out += residual
+        out += shortcut
         out = self.relu(out)
 
         return out
@@ -99,9 +99,9 @@ class DlaBottleneck(nn.Module):
         self.bn3 = nn.BatchNorm2d(outplanes)
         self.relu = nn.ReLU(inplace=True)
 
-    def forward(self, x, residual=None):
-        if residual is None:
-            residual = x
+    def forward(self, x, shortcut=None):
+        if shortcut is None:
+            shortcut = x
 
         out = self.conv1(x)
         out = self.bn1(out)
@@ -114,7 +114,7 @@ class DlaBottleneck(nn.Module):
         out = self.conv3(out)
         out = self.bn3(out)
 
-        out += residual
+        out += shortcut
         out = self.relu(out)
 
         return out
@@ -154,9 +154,9 @@ class DlaBottle2neck(nn.Module):
         self.bn3 = nn.BatchNorm2d(outplanes)
         self.relu = nn.ReLU(inplace=True)
 
-    def forward(self, x, residual=None):
-        if residual is None:
-            residual = x
+    def forward(self, x, shortcut=None):
+        if shortcut is None:
+            shortcut = x
 
         out = self.conv1(x)
         out = self.bn1(out)
@@ -177,26 +177,26 @@ class DlaBottle2neck(nn.Module):
         out = self.conv3(out)
         out = self.bn3(out)
 
-        out += residual
+        out += shortcut
         out = self.relu(out)
 
         return out
 
 
 class DlaRoot(nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size, residual):
+    def __init__(self, in_channels, out_channels, kernel_size, shortcut):
         super(DlaRoot, self).__init__()
         self.conv = nn.Conv2d(
             in_channels, out_channels, 1, stride=1, bias=False, padding=(kernel_size - 1) // 2)
         self.bn = nn.BatchNorm2d(out_channels)
         self.relu = nn.ReLU(inplace=True)
-        self.residual = residual
+        self.shortcut = shortcut
 
     def forward(self, *x):
         children = x
         x = self.conv(torch.cat(x, 1))
         x = self.bn(x)
-        if self.residual:
+        if self.shortcut:
             x += children[0]
         x = self.relu(x)
@@ -206,7 +206,7 @@ class DlaRoot(nn.Module):
 class DlaTree(nn.Module):
     def __init__(self, levels, block, in_channels, out_channels, stride=1,
                  dilation=1, cardinality=1, base_width=64,
-                 level_root=False, root_dim=0, root_kernel_size=1, root_residual=False):
+                 level_root=False, root_dim=0, root_kernel_size=1, root_shortcut=False):
         super(DlaTree, self).__init__()
         if root_dim == 0:
             root_dim = 2 * out_channels
@@ -226,24 +226,24 @@ class DlaTree(nn.Module):
                 nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                 nn.BatchNorm2d(out_channels))
         else:
-            cargs.update(dict(root_kernel_size=root_kernel_size, root_residual=root_residual))
+            cargs.update(dict(root_kernel_size=root_kernel_size, root_shortcut=root_shortcut))
             self.tree1 = DlaTree(
                 levels - 1, block, in_channels, out_channels, stride, root_dim=0, **cargs)
             self.tree2 = DlaTree(
                 levels - 1, block, out_channels, out_channels, root_dim=root_dim + out_channels, **cargs)
         if levels == 1:
-            self.root = DlaRoot(root_dim, out_channels, root_kernel_size, root_residual)
+            self.root = DlaRoot(root_dim, out_channels, root_kernel_size, root_shortcut)
         self.level_root = level_root
         self.root_dim = root_dim
         self.levels = levels
 
-    def forward(self, x, residual=None, children=None):
+    def forward(self, x, shortcut=None, children=None):
         children = [] if children is None else children
         bottom = self.downsample(x)
-        residual = self.project(bottom)
+        shortcut = self.project(bottom)
         if self.level_root:
             children.append(bottom)
-        x1 = self.tree1(x, residual)
+        x1 = self.tree1(x, shortcut)
         if self.levels == 1:
             x2 = self.tree2(x1)
             x = self.root(x2, x1, *children)
@@ -255,7 +255,7 @@ class DlaTree(nn.Module):
 class DLA(nn.Module):
     def __init__(self, levels, channels, output_stride=32, num_classes=1000, in_chans=3,
-                 cardinality=1, base_width=64, block=DlaBottle2neck, residual_root=False,
+                 cardinality=1, base_width=64, block=DlaBottle2neck, shortcut_root=False,
                  drop_rate=0.0, global_pool='avg'):
         super(DLA, self).__init__()
         self.channels = channels
@@ -271,7 +271,7 @@ class DLA(nn.Module):
             nn.ReLU(inplace=True))
         self.level0 = self._make_conv_level(channels[0], channels[0], levels[0])
         self.level1 = self._make_conv_level(channels[0], channels[1], levels[1], stride=2)
-        cargs = dict(cardinality=cardinality, base_width=base_width, root_residual=residual_root)
+        cargs = dict(cardinality=cardinality, base_width=base_width, root_shortcut=shortcut_root)
         self.level2 = DlaTree(levels[2], block, channels[1], channels[2], 2, level_root=False, **cargs)
         self.level3 = DlaTree(levels[3], block, channels[2], channels[3], 2, level_root=True, **cargs)
         self.level4 = DlaTree(levels[4], block, channels[3], channels[4], 2, level_root=True, **cargs)
@@ -413,7 +413,7 @@ def dla60x(pretrained=False, **kwargs):  # DLA-X-60
 def dla102(pretrained=False, **kwargs):  # DLA-102
     model_kwargs = dict(
         levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, residual_root=True, **kwargs)
+        block=DlaBottleneck, shortcut_root=True, **kwargs)
     return _create_dla('dla102', pretrained, **model_kwargs)
@@ -421,7 +421,7 @@ def dla102(pretrained=False, **kwargs):  # DLA-102
 def dla102x(pretrained=False, **kwargs):  # DLA-X-102
     model_kwargs = dict(
         levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, cardinality=32, base_width=4, residual_root=True, **kwargs)
+        block=DlaBottleneck, cardinality=32, base_width=4, shortcut_root=True, **kwargs)
     return _create_dla('dla102x', pretrained, **model_kwargs)
@@ -429,7 +429,7 @@ def dla102x(pretrained=False, **kwargs):  # DLA-X-102
 def dla102x2(pretrained=False, **kwargs):  # DLA-X-102 64
     model_kwargs = dict(
         levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, cardinality=64, base_width=4, residual_root=True, **kwargs)
+        block=DlaBottleneck, cardinality=64, base_width=4, shortcut_root=True, **kwargs)
     return _create_dla('dla102x2', pretrained, **model_kwargs)
@@ -437,5 +437,5 @@ def dla102x2(pretrained=False, **kwargs):  # DLA-X-102 64
 def dla169(pretrained=False, **kwargs):  # DLA-169
     model_kwargs = dict(
         levels=[1, 1, 2, 3, 5, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, residual_root=True, **kwargs)
+        block=DlaBottleneck, shortcut_root=True, **kwargs)
     return _create_dla('dla169', pretrained, **model_kwargs)
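Note: the rename is API-visible: DLA(..., shortcut_root=...) and DlaTree(..., root_shortcut=...) replace the old residual_root / root_residual keywords, so callers passing the old names would now get a TypeError. The registered entry points are unchanged; a usage sketch via the timm factory:

    import torch
    import timm

    model = timm.create_model('dla102', pretrained=False)  # builds with shortcut_root=True
    out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # torch.Size([1, 1000])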

@@ -184,7 +184,7 @@ class DepthwiseSeparableConv(nn.Module):
         return info
 
     def forward(self, x):
-        residual = x
+        shortcut = x
 
         x = self.conv_dw(x)
         x = self.bn1(x)
@@ -200,7 +200,7 @@ class DepthwiseSeparableConv(nn.Module):
         if self.has_residual:
             if self.drop_path_rate > 0.:
                 x = drop_path(x, self.drop_path_rate, self.training)
-            x += residual
+            x += shortcut
         return x
@@ -258,7 +258,7 @@ class InvertedResidual(nn.Module):
         return info
 
     def forward(self, x):
-        residual = x
+        shortcut = x
 
         # Point-wise expansion
         x = self.conv_pw(x)
@@ -281,7 +281,7 @@ class InvertedResidual(nn.Module):
         if self.has_residual:
             if self.drop_path_rate > 0.:
                 x = drop_path(x, self.drop_path_rate, self.training)
-            x += residual
+            x += shortcut
         return x
@@ -308,7 +308,7 @@ class CondConvResidual(InvertedResidual):
         self.routing_fn = nn.Linear(in_chs, self.num_experts)
 
     def forward(self, x):
-        residual = x
+        shortcut = x
 
         # CondConv routing
         pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1)
@@ -335,7 +335,7 @@ class CondConvResidual(InvertedResidual):
         if self.has_residual:
             if self.drop_path_rate > 0.:
                 x = drop_path(x, self.drop_path_rate, self.training)
-            x += residual
+            x += shortcut
         return x
@@ -390,7 +390,7 @@ class EdgeResidual(nn.Module):
         return info
 
     def forward(self, x):
-        residual = x
+        shortcut = x
 
         # Expansion convolution
         x = self.conv_exp(x)
@@ -408,6 +408,6 @@ class EdgeResidual(nn.Module):
         if self.has_residual:
             if self.drop_path_rate > 0.:
                 x = drop_path(x, self.drop_path_rate, self.training)
-            x += residual
+            x += shortcut
         return x
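Note: all four block types above share the same residual skeleton: stash the input as shortcut, transform x, optionally apply stochastic depth, then add. A distilled sketch of that shared structure (simplified; the real blocks interleave BN, activations and SE as shown):

    import torch.nn as nn
    from timm.models.layers import DropPath

    class ResidualSketch(nn.Module):
        def __init__(self, chs, drop_path_rate=0.):
            super().__init__()
            self.conv = nn.Conv2d(chs, chs, 3, padding=1)  # stride 1, in == out, so residual applies
            self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

        def forward(self, x):
            shortcut = x
            x = self.conv(x)
            return self.drop_path(x) + shortcut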

@@ -112,7 +112,7 @@ class GhostBottleneck(nn.Module):
 
     def forward(self, x):
-        residual = x
+        shortcut = x
 
         # 1st ghost bottleneck
         x = self.ghost1(x)
@@ -129,7 +129,7 @@ class GhostBottleneck(nn.Module):
         # 2nd ghost bottleneck
         x = self.ghost2(x)
 
-        x += self.shortcut(residual)
+        x += self.shortcut(shortcut)
         return x

@@ -1,7 +1,6 @@
 from .activations import *
 from .adaptive_avgmax_pool import \
     adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
-from .anti_aliasing import AntiAliasDownsampleLayer
 from .blur_pool import BlurPool2d
 from .classifier import ClassifierHead, create_classifier
 from .cond_conv2d import CondConv2d, get_condconv_initializer

@@ -1,60 +0,0 @@
-import torch
-import torch.nn.parallel
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class AntiAliasDownsampleLayer(nn.Module):
-    def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2, no_jit: bool = False):
-        super(AntiAliasDownsampleLayer, self).__init__()
-        if no_jit:
-            self.op = Downsample(channels, filt_size, stride)
-        else:
-            self.op = DownsampleJIT(channels, filt_size, stride)
-        # FIXME I should probably override _apply and clear DownsampleJIT filter cache for .cuda(), .half(), etc calls
-
-    def forward(self, x):
-        return self.op(x)
-
-
-@torch.jit.script
-class DownsampleJIT(object):
-    def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2):
-        self.channels = channels
-        self.stride = stride
-        self.filt_size = filt_size
-        assert self.filt_size == 3
-        assert stride == 2
-        self.filt = {}  # lazy init by device for DataParallel compat
-
-    def _create_filter(self, like: torch.Tensor):
-        filt = torch.tensor([1., 2., 1.], dtype=like.dtype, device=like.device)
-        filt = filt[:, None] * filt[None, :]
-        filt = filt / torch.sum(filt)
-        return filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
-
-    def __call__(self, input: torch.Tensor):
-        input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
-        filt = self.filt.get(str(input.device), self._create_filter(input))
-        return F.conv2d(input_pad, filt, stride=2, padding=0, groups=input.shape[1])
-
-
-class Downsample(nn.Module):
-    def __init__(self, channels=None, filt_size=3, stride=2):
-        super(Downsample, self).__init__()
-        self.channels = channels
-        self.filt_size = filt_size
-        self.stride = stride
-        assert self.filt_size == 3
-        filt = torch.tensor([1., 2., 1.])
-        filt = filt[:, None] * filt[None, :]
-        filt = filt / torch.sum(filt)
-        # self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
-        self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))
-
-    def forward(self, input):
-        input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
-        return F.conv2d(input_pad, self.filt, stride=self.stride, padding=0, groups=input.shape[1])

@@ -3,8 +3,6 @@ BlurPool layer inspired by
  - Kornia's Max_BlurPool2d
  - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`
 
-FIXME merge this impl with those in `anti_aliasing.py`
-
 Hacked together by Chris Ha and Ross Wightman
 """
@@ -12,7 +10,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
-from typing import Dict
 
 from .padding import get_padding
@@ -29,30 +26,17 @@ class BlurPool2d(nn.Module):
     Returns:
         torch.Tensor: the transformed tensor.
     """
-    filt: Dict[str, torch.Tensor]
-
     def __init__(self, channels, filt_size=3, stride=2) -> None:
         super(BlurPool2d, self).__init__()
         assert filt_size > 1
         self.channels = channels
         self.filt_size = filt_size
         self.stride = stride
-        pad_size = [get_padding(filt_size, stride, dilation=1)] * 4
-        self.padding = nn.ReflectionPad2d(pad_size)
-        self._coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs)  # for torchscript compat
-        self.filt = {}  # lazy init by device for DataParallel compat
-
-    def _create_filter(self, like: torch.Tensor):
-        blur_filter = (self._coeffs[:, None] * self._coeffs[None, :]).to(dtype=like.dtype, device=like.device)
-        return blur_filter[None, None, :, :].repeat(self.channels, 1, 1, 1)
-
-    def _apply(self, fn):
-        # override nn.Module _apply, reset filter cache if used
-        self.filt = {}
-        super(BlurPool2d, self)._apply(fn)
-
-    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
-        C = input_tensor.shape[1]
-        blur_filt = self.filt.get(str(input_tensor.device), self._create_filter(input_tensor))
-        return F.conv2d(
-            self.padding(input_tensor), blur_filt, stride=self.stride, groups=C)
+        self.padding = [get_padding(filt_size, stride, dilation=1)] * 4
+        coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32))
+        blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :].repeat(self.channels, 1, 1, 1)
+        self.register_buffer('filt', blur_filter, persistent=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = F.pad(x, self.padding, 'reflect')
+        return F.conv2d(x, self.filt, stride=self.stride, groups=x.shape[1])
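Note: the rewrite drops the per-device filter cache (and the _apply override that invalidated it) in favor of a single non-persistent buffer, so .cuda()/.half() and DataParallel replication are handled by nn.Module itself, and state_dicts stay clean. A quick shape check, following the padding/stride math above:

    import torch
    from timm.models.layers import BlurPool2d

    pool = BlurPool2d(channels=32, filt_size=3, stride=2)
    y = pool(torch.randn(2, 32, 56, 56))
    print(y.shape)  # torch.Size([2, 32, 28, 28]) - anti-aliased 2x downsample
    print('filt' in dict(pool.named_buffers()))  # True, though excluded from state_dict (persistent=False)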

@@ -21,6 +21,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from .helpers import to_2tuple
+from .weight_init import trunc_normal_
 
 
 def rel_logits_1d(q, rel_k, permute_mask: List[int]):
@@ -101,6 +102,11 @@ class BottleneckAttn(nn.Module):
 
         self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
 
+    def reset_parameters(self):
+        trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)
+        trunc_normal_(self.pos_embed.height_rel, std=self.scale)
+        trunc_normal_(self.pos_embed.width_rel, std=self.scale)
+
     def forward(self, x):
         B, C, H, W = x.shape
         assert H == self.pos_embed.height and W == self.pos_embed.width

@@ -25,6 +25,8 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 
+from .weight_init import trunc_normal_
+
 
 def rel_logits_1d(q, rel_k, permute_mask: List[int]):
     """ Compute relative logits along one dimension
@@ -124,6 +126,13 @@ class HaloAttn(nn.Module):
         self.pos_embed = PosEmbedRel(
             block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head, scale=self.scale)
 
+    def reset_parameters(self):
+        std = self.q.weight.shape[1] ** -0.5  # fan-in
+        trunc_normal_(self.q.weight, std=std)
+        trunc_normal_(self.kv.weight, std=std)
+        trunc_normal_(self.pos_embed.height_rel, std=self.scale)
+        trunc_normal_(self.pos_embed.width_rel, std=self.scale)
+
     def forward(self, x):
         B, C, H, W = x.shape
         assert H % self.block_size == 0 and W % self.block_size == 0

@@ -24,6 +24,7 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 
+from .weight_init import trunc_normal_
 
 
 class LambdaLayer(nn.Module):
@@ -36,6 +37,7 @@ class LambdaLayer(nn.Module):
             self,
             dim, dim_out=None, stride=1, num_heads=4, dim_head=16, r=7, qkv_bias=False):
         super().__init__()
+        self.dim = dim
         self.dim_out = dim_out or dim
         self.dim_k = dim_head  # query depth 'k'
         self.num_heads = num_heads
@@ -55,6 +57,10 @@ class LambdaLayer(nn.Module):
 
         self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
 
+    def reset_parameters(self):
+        trunc_normal_(self.qkv.weight, std=self.dim ** -0.5)
+        trunc_normal_(self.conv_lambda.weight, std=self.dim_k ** -0.5)
+
     def forward(self, x):
         B, C, H, W = x.shape
         M = H * W
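Note: the three reset_parameters() additions (bottleneck, halo, lambda) follow one rule: truncated normal with std = fan_in ** -0.5 for projection weights, and std = self.scale for the relative position parameters. A small check of the fan-in rule on a 1x1 conv:

    import torch
    from timm.models.layers import trunc_normal_

    qkv = torch.nn.Conv2d(256, 3 * 256, kernel_size=1, bias=False)
    std = qkv.weight.shape[1] ** -0.5  # fan-in of a 1x1 conv = in_channels = 256
    trunc_normal_(qkv.weight, std=std)
    print(round(std, 4), round(qkv.weight.std().item(), 4))  # both ~ 0.0625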

@@ -107,6 +107,7 @@ class WindowAttention(nn.Module):
         self.relative_position_bias_table = nn.Parameter(
             # 2 * Wh - 1 * 2 * Ww - 1, nH
             torch.zeros((2 * self.win_size - 1) * (2 * self.win_size - 1), num_heads))
+        trunc_normal_(self.relative_position_bias_table, std=.02)
 
         # get pair-wise relative position index for each token inside the window
         coords_h = torch.arange(self.win_size)
@@ -120,13 +121,16 @@ class WindowAttention(nn.Module):
         relative_coords[:, :, 0] *= 2 * self.win_size - 1
         relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
         self.register_buffer("relative_position_index", relative_position_index)
-        trunc_normal_(self.relative_position_bias_table, std=.02)
 
         self.qkv = nn.Linear(dim, self.dim_out * 3, bias=qkv_bias)
         self.attn_drop = nn.Dropout(attn_drop)
         self.softmax = nn.Softmax(dim=-1)
         self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
 
+    def reset_parameters(self):
+        trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)
+        trunc_normal_(self.relative_position_bias_table, std=.02)
+
     def forward(self, x):
         B, C, H, W = x.shape
         x = x.permute(0, 2, 3, 1)
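Note: for reference, relative_position_index maps every ordered pair of window positions to a row of the bias table; for win_size=2 it is a 4x4 matrix with values in [0, 9) and the "zero offset" bucket 4 on the diagonal. A standalone replication of the computation in this hunk:

    import torch

    win_size = 2
    coords = torch.stack(torch.meshgrid([torch.arange(win_size), torch.arange(win_size)]))
    coords_flatten = torch.flatten(coords, 1)                      # 2, Wh*Ww
    rel = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, N, N
    rel = rel.permute(1, 2, 0).contiguous()
    rel += win_size - 1                                            # shift offsets to start at 0
    rel[:, :, 0] *= 2 * win_size - 1
    print(rel.sum(-1))  # tensor([[4, 3, 1, 0], [5, 4, 2, 1], [7, 6, 4, 3], [8, 7, 5, 4]])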

@@ -0,0 +1,292 @@
+""" MLP-Mixer in PyTorch
+
+Official JAX impl: https://github.com/google-research/vision_transformer/blob/linen/vit_jax/models_mixer.py
+
+Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
+
+@article{tolstikhin2021,
+  title={MLP-Mixer: An all-MLP Architecture for Vision},
+  author={Tolstikhin, Ilya and Houlsby, Neil and Kolesnikov, Alexander and Beyer, Lucas and Zhai, Xiaohua and Unterthiner,
+      Thomas and Yung, Jessica and Keysers, Daniel and Uszkoreit, Jakob and Lucic, Mario and Dosovitskiy, Alexey},
+  journal={arXiv preprint arXiv:2105.01601},
+  year={2021}
+}
+
+A thank you to paper authors for releasing code and weights.
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+import math
+from copy import deepcopy
+from functools import partial
+
+import torch
+import torch.nn as nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import build_model_with_cfg, overlay_external_default_cfg
+from .layers import DropPath, to_2tuple, lecun_normal_
+from .registry import register_model
+
+
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'crop_pct': 0.875, 'interpolation': 'bicubic', 'fixed_input_size': True,
+        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
+        'first_conv': 'stem.proj', 'classifier': 'head',
+        **kwargs
+    }
+
+
+default_cfgs = dict(
+    mixer_s32_224=_cfg(),
+    mixer_s16_224=_cfg(),
+    mixer_b32_224=_cfg(),
+    mixer_b16_224=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224-76587d61.pth',
+    ),
+    mixer_b16_224_in21k=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224_in21k-617b3de2.pth',
+        num_classes=21843
+    ),
+    mixer_l32_224=_cfg(),
+    mixer_l16_224=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224-92f9adc4.pth',
+    ),
+    mixer_l16_224_in21k=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224_in21k-846aa33c.pth',
+        num_classes=21843
+    ),
+)
+
+
+class Mlp(nn.Module):
+    """ MLP Block
+    NOTE: same impl as ViT, move to common location
+    """
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class PatchEmbed(nn.Module):
+    """ Image to Patch Embedding
+    NOTE: same impl as ViT, move to common location
+    """
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.patch_grid = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+        self.num_patches = self.patch_grid[0] * self.patch_grid[1]
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        # FIXME look at relaxing size constraints
+        assert H == self.img_size[0] and W == self.img_size[1], \
+            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        x = self.proj(x).flatten(2).transpose(1, 2)
+        x = self.norm(x)
+        return x
+
+
+class MixerBlock(nn.Module):
+    def __init__(
+            self, dim, seq_len, tokens_dim, channels_dim,
+            norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.mlp_tokens = Mlp(seq_len, tokens_dim, act_layer=act_layer, drop=drop)
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        self.mlp_channels = Mlp(dim, channels_dim, act_layer=act_layer, drop=drop)
+
+    def forward(self, x):
+        x = x + self.drop_path(self.mlp_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2))
+        x = x + self.drop_path(self.mlp_channels(self.norm2(x)))
+        return x
+
+
+class MlpMixer(nn.Module):
+
+    def __init__(
+            self,
+            num_classes=1000,
+            img_size=224,
+            in_chans=3,
+            patch_size=16,
+            num_blocks=8,
+            hidden_dim=512,
+            tokens_dim=256,
+            channels_dim=2048,
+            norm_layer=partial(nn.LayerNorm, eps=1e-6),
+            act_layer=nn.GELU,
+            drop=0.,
+            drop_path=0.,
+            nlhb=False,
+    ):
+        super().__init__()
+        self.num_classes = num_classes
+
+        self.stem = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=hidden_dim)
+        # FIXME drop_path (stochastic depth scaling rule?)
+        self.blocks = nn.Sequential(*[
+            MixerBlock(
+                hidden_dim, self.stem.num_patches, tokens_dim, channels_dim,
+                norm_layer=norm_layer, act_layer=act_layer, drop=drop, drop_path=drop_path)
+            for _ in range(num_blocks)])
+        self.norm = norm_layer(hidden_dim)
+        self.head = nn.Linear(hidden_dim, self.num_classes)  # zero init
+
+        self.init_weights(nlhb=nlhb)
+
+    def init_weights(self, nlhb=False):
+        head_bias = -math.log(self.num_classes) if nlhb else 0.
+        for n, m in self.named_modules():
+            _init_weights(m, n, head_bias=head_bias)
+
+    def forward(self, x):
+        x = self.stem(x)
+        x = self.blocks(x)
+        x = self.norm(x)
+        x = x.mean(dim=1)
+        x = self.head(x)
+        return x
+
+
+def _init_weights(m, n: str, head_bias: float = 0.):
+    """ Mixer weight initialization (trying to match Flax defaults)
+    """
+    if isinstance(m, nn.Linear):
+        if n.startswith('head'):
+            nn.init.zeros_(m.weight)
+            nn.init.constant_(m.bias, head_bias)
+        else:
+            nn.init.xavier_uniform_(m.weight)
+            if m.bias is not None:
+                if 'mlp' in n:
+                    nn.init.normal_(m.bias, std=1e-6)
+                else:
+                    nn.init.zeros_(m.bias)
+    elif isinstance(m, nn.Conv2d):
+        lecun_normal_(m.weight)
+        if m.bias is not None:
+            nn.init.zeros_(m.bias)
+    elif isinstance(m, nn.LayerNorm):
+        nn.init.zeros_(m.bias)
+        nn.init.ones_(m.weight)
+
+
+def _create_mixer(variant, pretrained=False, default_cfg=None, **kwargs):
+    if default_cfg is None:
+        default_cfg = deepcopy(default_cfgs[variant])
+    overlay_external_default_cfg(default_cfg, kwargs)
+    default_num_classes = default_cfg['num_classes']
+    default_img_size = default_cfg['input_size'][-2:]
+
+    num_classes = kwargs.pop('num_classes', default_num_classes)
+    img_size = kwargs.pop('img_size', default_img_size)
+    if kwargs.get('features_only', None):
+        raise RuntimeError('features_only not implemented for MLP-Mixer models.')
+
+    model = build_model_with_cfg(
+        MlpMixer, variant, pretrained,
+        default_cfg=default_cfg,
+        img_size=img_size,
+        num_classes=num_classes,
+        **kwargs)
+
+    return model
+
+
+@register_model
+def mixer_s32_224(pretrained=False, **kwargs):
+    """ Mixer-S/32 224x224
+    """
+    model_args = dict(patch_size=32, num_blocks=8, hidden_dim=512, tokens_dim=256, channels_dim=2048, **kwargs)
+    model = _create_mixer('mixer_s32_224', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def mixer_s16_224(pretrained=False, **kwargs):
+    """ Mixer-S/16 224x224
+    """
+    model_args = dict(patch_size=16, num_blocks=8, hidden_dim=512, tokens_dim=256, channels_dim=2048, **kwargs)
+    model = _create_mixer('mixer_s16_224', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def mixer_b32_224(pretrained=False, **kwargs):
+    """ Mixer-B/32 224x224
+    """
+    model_args = dict(patch_size=32, num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072, **kwargs)
+    model = _create_mixer('mixer_b32_224', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def mixer_b16_224(pretrained=False, **kwargs):
+    """ Mixer-B/16 224x224. ImageNet-1k pretrained weights.
+    """
+    model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072, **kwargs)
+    model = _create_mixer('mixer_b16_224', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def mixer_b16_224_in21k(pretrained=False, **kwargs):
+    """ Mixer-B/16 224x224. ImageNet-21k pretrained weights.
+    """
+    model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072, **kwargs)
+    model = _create_mixer('mixer_b16_224_in21k', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def mixer_l32_224(pretrained=False, **kwargs):
+    """ Mixer-L/32 224x224.
+    """
+    model_args = dict(patch_size=32, num_blocks=24, hidden_dim=1024, tokens_dim=512, channels_dim=4096, **kwargs)
+    model = _create_mixer('mixer_l32_224', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def mixer_l16_224(pretrained=False, **kwargs):
+    """ Mixer-L/16 224x224. ImageNet-1k pretrained weights.
+    """
+    model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, tokens_dim=512, channels_dim=4096, **kwargs)
+    model = _create_mixer('mixer_l16_224', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def mixer_l16_224_in21k(pretrained=False, **kwargs):
+    """ Mixer-L/16 224x224. ImageNet-21k pretrained weights.
+    """
+    model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, tokens_dim=512, channels_dim=4096, **kwargs)
+    model = _create_mixer('mixer_l16_224_in21k', pretrained=pretrained, **model_args)
+    return model
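Note: a quick smoke test of the new file through the registry (pretrained=True would additionally download the weights listed above). For the B/16 models the stem turns a 224x224 image into 196 patch tokens of width 768; token-mixing MLPs act across the 196 axis, channel-mixing MLPs across the 768 axis:

    import torch
    import timm

    model = timm.create_model('mixer_b16_224', pretrained=False)
    x = torch.randn(1, 3, 224, 224)
    print(model.stem(x).shape)  # torch.Size([1, 196, 768]) - (B, num_patches, hidden_dim)
    print(model(x).shape)       # torch.Size([1, 1000])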

@@ -91,7 +91,7 @@ class Bottle2neck(nn.Module):
             nn.init.zeros_(self.bn3.weight)
 
     def forward(self, x):
-        residual = x
+        shortcut = x
 
         out = self.conv1(x)
         out = self.bn1(out)
@@ -124,9 +124,9 @@ class Bottle2neck(nn.Module):
             out = self.se(out)
 
         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)
 
-        out += residual
+        out += shortcut
         out = self.relu(out)
 
         return out

@@ -105,7 +105,7 @@ class ResNestBottleneck(nn.Module):
             nn.init.zeros_(self.bn3.weight)
 
     def forward(self, x):
-        residual = x
+        shortcut = x
 
         out = self.conv1(x)
         out = self.bn1(out)
@@ -132,9 +132,9 @@ class ResNestBottleneck(nn.Module):
             out = self.drop_block(out)
 
         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)
 
-        out += residual
+        out += shortcut
         out = self.act3(out)
 
         return out

@@ -241,31 +241,31 @@ default_cfgs = {
     # ResNet-RS models
     'resnetrs50': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs50-7c9728e2.pth',
-        input_size=(3, 160, 160), pool_size=(4, 4), crop_pct=0.91, test_input_size=(3, 224, 224),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs50_ema-6b53758b.pth',
+        input_size=(3, 160, 160), pool_size=(5, 5), crop_pct=0.91, test_input_size=(3, 224, 224),
         interpolation='bicubic', first_conv='conv1.0'),
     'resnetrs101': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs101-3e4bb55c.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs101_i192_ema-1509bbf6.pth',
         input_size=(3, 192, 192), pool_size=(6, 6), crop_pct=0.94, test_input_size=(3, 288, 288),
         interpolation='bicubic', first_conv='conv1.0'),
     'resnetrs152': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs152-b1efe56d.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs152_i256_ema-a9aff7f9.pth',
         input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320),
         interpolation='bicubic', first_conv='conv1.0'),
     'resnetrs200': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs200-b455b791.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs200_ema-623d2f59.pth',
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320),
         interpolation='bicubic', first_conv='conv1.0'),
     'resnetrs270': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs270-cafcfbc7.pth',
-        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs270_ema-b40e674c.pth',
+        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 352, 352),
         interpolation='bicubic', first_conv='conv1.0'),
     'resnetrs350': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs350-06d9bfac.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs350_i256_ema-5a1aa8f1.pth',
         input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0, test_input_size=(3, 384, 384),
         interpolation='bicubic', first_conv='conv1.0'),
     'resnetrs420': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs420-d26764a5.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs420_ema-972dee69.pth',
         input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, test_input_size=(3, 416, 416),
         interpolation='bicubic', first_conv='conv1.0'),
 }
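Note: each ResNet-RS entry pairs a training resolution (input_size, with pool_size = input / 32) with a larger test_input_size, per the train-test resolution discrepancy the RS recipes exploit; the new resnetrs270 EMA weights also raise the eval size from 320 to 352. The values are readable from the model's default_cfg:

    import timm

    model = timm.create_model('resnetrs50', pretrained=False)
    print(model.default_cfg['input_size'])       # (3, 160, 160) - train resolution
    print(model.default_cfg['test_input_size'])  # (3, 224, 224) - evaluate larger
    print(model.default_cfg['pool_size'])        # (5, 5) = 160 / 32 feature map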
@@ -315,7 +315,7 @@ class BasicBlock(nn.Module):
             nn.init.zeros_(self.bn2.weight)
 
     def forward(self, x):
-        residual = x
+        shortcut = x
 
         x = self.conv1(x)
         x = self.bn1(x)
@@ -337,8 +337,8 @@ class BasicBlock(nn.Module):
             x = self.drop_path(x)
 
         if self.downsample is not None:
-            residual = self.downsample(residual)
-        x += residual
+            shortcut = self.downsample(shortcut)
+        x += shortcut
         x = self.act2(x)
 
         return x
@@ -385,7 +385,7 @@ class Bottleneck(nn.Module):
             nn.init.zeros_(self.bn3.weight)
 
     def forward(self, x):
-        residual = x
+        shortcut = x
 
         x = self.conv1(x)
         x = self.bn1(x)
@@ -413,8 +413,8 @@ class Bottleneck(nn.Module):
             x = self.drop_path(x)
 
         if self.downsample is not None:
-            residual = self.downsample(residual)
-        x += residual
+            shortcut = self.downsample(shortcut)
+        x += shortcut
         x = self.act3(x)
 
         return x

@@ -92,7 +92,7 @@ class Bottleneck(nn.Module):
     """
 
     def forward(self, x):
-        residual = x
+        shortcut = x
 
         out = self.conv1(x)
         out = self.bn1(out)
@@ -106,9 +106,9 @@ class Bottleneck(nn.Module):
         out = self.bn3(out)
 
         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)
 
-        out = self.se_module(out) + residual
+        out = self.se_module(out) + shortcut
         out = self.relu(out)
 
         return out
@@ -204,7 +204,7 @@ class SEResNetBlock(nn.Module):
         self.stride = stride
 
     def forward(self, x):
-        residual = x
+        shortcut = x
 
         out = self.conv1(x)
         out = self.bn1(out)
@@ -215,9 +215,9 @@ class SEResNetBlock(nn.Module):
         out = self.relu(out)
 
         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)
 
-        out = self.se_module(out) + residual
+        out = self.se_module(out) + shortcut
         out = self.relu(out)
 
         return out

@@ -76,7 +76,7 @@ class SelectiveKernelBasic(nn.Module):
             nn.init.zeros_(self.conv2.bn.weight)
 
     def forward(self, x):
-        residual = x
+        shortcut = x
         x = self.conv1(x)
         x = self.conv2(x)
         if self.se is not None:
@@ -84,8 +84,8 @@ class SelectiveKernelBasic(nn.Module):
         if self.drop_path is not None:
             x = self.drop_path(x)
         if self.downsample is not None:
-            residual = self.downsample(residual)
-        x += residual
+            shortcut = self.downsample(shortcut)
+        x += shortcut
         x = self.act(x)
         return x
@@ -124,7 +124,7 @@ class SelectiveKernelBottleneck(nn.Module):
             nn.init.zeros_(self.conv3.bn.weight)
 
     def forward(self, x):
-        residual = x
+        shortcut = x
         x = self.conv1(x)
         x = self.conv2(x)
         x = self.conv3(x)
@@ -133,8 +133,8 @@ class SelectiveKernelBottleneck(nn.Module):
         if self.drop_path is not None:
            x = self.drop_path(x)
         if self.downsample is not None:
-            residual = self.downsample(residual)
-        x += residual
+            shortcut = self.downsample(shortcut)
+        x += shortcut
         x = self.act(x)
         return x

@@ -5,16 +5,13 @@ https://arxiv.org/pdf/2003.13630.pdf
 Original model: https://github.com/mrT23/TResNet
 
 """
-import copy
 from collections import OrderedDict
-from functools import partial
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from .helpers import build_model_with_cfg
-from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, InplaceAbn, ClassifierHead, SEModule
+from .layers import SpaceToDepthModule, BlurPool2d, InplaceAbn, ClassifierHead, SEModule
 from .registry import register_model
 
 __all__ = ['tresnet_m', 'tresnet_l', 'tresnet_xl']
@@ -92,9 +89,9 @@ class BasicBlock(nn.Module):
 
     def forward(self, x):
         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)
         else:
-            residual = x
+            shortcut = x
 
         out = self.conv1(x)
         out = self.conv2(out)
@@ -102,7 +99,7 @@ class BasicBlock(nn.Module):
         if self.se is not None:
             out = self.se(out)
 
-        out += residual
+        out += shortcut
         out = self.relu(out)
 
         return out
@@ -139,9 +136,9 @@ class Bottleneck(nn.Module):
 
     def forward(self, x):
         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)
         else:
-            residual = x
+            shortcut = x
 
         out = self.conv1(x)
         out = self.conv2(out)
@@ -149,22 +146,19 @@ class Bottleneck(nn.Module):
             out = self.se(out)
 
         out = self.conv3(out)
-        out = out + residual  # no inplace
+        out = out + shortcut  # no inplace
         out = self.relu(out)
 
         return out
 
 
 class TResNet(nn.Module):
-    def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, no_aa_jit=False,
-                 global_pool='fast', drop_rate=0.):
+    def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, global_pool='fast', drop_rate=0.):
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         super(TResNet, self).__init__()
 
-        # JIT layers
-        space_to_depth = SpaceToDepthModule()
-        aa_layer = partial(AntiAliasDownsampleLayer, no_jit=no_aa_jit)
+        aa_layer = BlurPool2d
 
         # TResnet stages
         self.inplanes = int(64 * width_factor)
@@ -181,7 +175,7 @@ class TResNet(nn.Module):
         # body
         self.body = nn.Sequential(OrderedDict([
-            ('SpaceToDepth', space_to_depth),
+            ('SpaceToDepth', SpaceToDepthModule()),
             ('conv1', conv1),
             ('layer1', layer1),
             ('layer2', layer2),