Merge branch 'master' into cait

pull/609/head
Ross Wightman 3 years ago
commit 5fcddb96a8

@@ -48,7 +48,7 @@ parser = argparse.ArgumentParser(description='PyTorch Benchmark')
 parser.add_argument('--model-list', metavar='NAME', default='',
                     help='txt file based list of model names to benchmark')
 parser.add_argument('--bench', default='both', type=str,
-                    help="Benchmark mode. One of 'inference', 'train', 'both'. Defaults to 'inference'")
+                    help="Benchmark mode. One of 'inference', 'train', 'both'. Defaults to 'both'")
 parser.add_argument('--detail', action='store_true', default=False,
                     help='Provide train fwd/bwd/opt breakdown detail if True. Defaults to False')
 parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',

@@ -15,7 +15,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
     torch._C._jit_set_profiling_mode(False)

 # transformer models don't support many of the spatial / feature based model functionalities
-NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*']
+NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', 'mixer_*']
 NUM_NON_STD = len(NON_STD_FILTERS)

 # exclude models that cause specific test failures

@@ -15,6 +15,7 @@ from .hrnet import *
 from .inception_resnet_v2 import *
 from .inception_v3 import *
 from .inception_v4 import *
+from .mlp_mixer import *
 from .mobilenetv3 import *
 from .nasnet import *
 from .nfnet import *

@@ -294,6 +294,8 @@ class SelfAttnBlock(nn.Module):
     def init_weights(self, zero_init_last_bn=False):
         if zero_init_last_bn:
             nn.init.zeros_(self.conv3_1x1.bn.weight)
+        if hasattr(self.self_attn, 'reset_parameters'):
+            self.self_attn.reset_parameters()

     def forward(self, x):
         shortcut = self.shortcut(x)

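The `hasattr` guard above lets `init_weights` stay agnostic to whichever self-attention layer the block wraps, calling a layer's own `reset_parameters()` only when it defines one (the attention layers later in this commit all gain that hook). A minimal sketch of the duck-typed dispatch, with toy module names that are not timm code:

```python
import torch.nn as nn

class ToyAttn(nn.Module):
    """Stand-in attention layer exposing a reset_parameters() hook."""
    def __init__(self, dim):
        super().__init__()
        self.qkv = nn.Conv2d(dim, dim * 3, 1, bias=False)

    def reset_parameters(self):
        # fan-in scaled init, mirroring the hooks added later in this commit
        nn.init.normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)

class ToyBlock(nn.Module):
    def __init__(self, dim, attn_layer=ToyAttn):
        super().__init__()
        self.self_attn = attn_layer(dim)

    def init_weights(self):
        # only call the hook if the wrapped layer defines one
        if hasattr(self.self_attn, 'reset_parameters'):
            self.self_attn.reset_parameters()

ToyBlock(64).init_weights()
```
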
@@ -62,9 +62,9 @@ class DlaBasic(nn.Module):
         self.bn2 = nn.BatchNorm2d(planes)
         self.stride = stride

-    def forward(self, x, residual=None):
-        if residual is None:
-            residual = x
+    def forward(self, x, shortcut=None):
+        if shortcut is None:
+            shortcut = x

         out = self.conv1(x)
         out = self.bn1(out)
@@ -73,7 +73,7 @@ class DlaBasic(nn.Module):
         out = self.conv2(out)
         out = self.bn2(out)

-        out += residual
+        out += shortcut
         out = self.relu(out)

         return out
@@ -99,9 +99,9 @@ class DlaBottleneck(nn.Module):
         self.bn3 = nn.BatchNorm2d(outplanes)
         self.relu = nn.ReLU(inplace=True)

-    def forward(self, x, residual=None):
-        if residual is None:
-            residual = x
+    def forward(self, x, shortcut=None):
+        if shortcut is None:
+            shortcut = x

         out = self.conv1(x)
         out = self.bn1(out)
@@ -114,7 +114,7 @@ class DlaBottleneck(nn.Module):
         out = self.conv3(out)
         out = self.bn3(out)

-        out += residual
+        out += shortcut
         out = self.relu(out)

         return out
@@ -154,9 +154,9 @@ class DlaBottle2neck(nn.Module):
         self.bn3 = nn.BatchNorm2d(outplanes)
         self.relu = nn.ReLU(inplace=True)

-    def forward(self, x, residual=None):
-        if residual is None:
-            residual = x
+    def forward(self, x, shortcut=None):
+        if shortcut is None:
+            shortcut = x

         out = self.conv1(x)
         out = self.bn1(out)
@@ -177,26 +177,26 @@ class DlaBottle2neck(nn.Module):
         out = self.conv3(out)
         out = self.bn3(out)

-        out += residual
+        out += shortcut
         out = self.relu(out)

         return out


 class DlaRoot(nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size, residual):
+    def __init__(self, in_channels, out_channels, kernel_size, shortcut):
         super(DlaRoot, self).__init__()
         self.conv = nn.Conv2d(
             in_channels, out_channels, 1, stride=1, bias=False, padding=(kernel_size - 1) // 2)
         self.bn = nn.BatchNorm2d(out_channels)
         self.relu = nn.ReLU(inplace=True)
-        self.residual = residual
+        self.shortcut = shortcut

     def forward(self, *x):
         children = x
         x = self.conv(torch.cat(x, 1))
         x = self.bn(x)
-        if self.residual:
+        if self.shortcut:
             x += children[0]
         x = self.relu(x)
@@ -206,7 +206,7 @@ class DlaRoot(nn.Module):
 class DlaTree(nn.Module):
     def __init__(self, levels, block, in_channels, out_channels, stride=1,
                  dilation=1, cardinality=1, base_width=64,
-                 level_root=False, root_dim=0, root_kernel_size=1, root_residual=False):
+                 level_root=False, root_dim=0, root_kernel_size=1, root_shortcut=False):
         super(DlaTree, self).__init__()
         if root_dim == 0:
             root_dim = 2 * out_channels
@@ -226,24 +226,24 @@ class DlaTree(nn.Module):
                 nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                 nn.BatchNorm2d(out_channels))
         else:
-            cargs.update(dict(root_kernel_size=root_kernel_size, root_residual=root_residual))
+            cargs.update(dict(root_kernel_size=root_kernel_size, root_shortcut=root_shortcut))
             self.tree1 = DlaTree(
                 levels - 1, block, in_channels, out_channels, stride, root_dim=0, **cargs)
             self.tree2 = DlaTree(
                 levels - 1, block, out_channels, out_channels, root_dim=root_dim + out_channels, **cargs)
         if levels == 1:
-            self.root = DlaRoot(root_dim, out_channels, root_kernel_size, root_residual)
+            self.root = DlaRoot(root_dim, out_channels, root_kernel_size, root_shortcut)
         self.level_root = level_root
         self.root_dim = root_dim
         self.levels = levels

-    def forward(self, x, residual=None, children=None):
+    def forward(self, x, shortcut=None, children=None):
         children = [] if children is None else children
         bottom = self.downsample(x)
-        residual = self.project(bottom)
+        shortcut = self.project(bottom)
         if self.level_root:
             children.append(bottom)
-        x1 = self.tree1(x, residual)
+        x1 = self.tree1(x, shortcut)
         if self.levels == 1:
             x2 = self.tree2(x1)
             x = self.root(x2, x1, *children)
@@ -255,7 +255,7 @@ class DlaTree(nn.Module):
 class DLA(nn.Module):
     def __init__(self, levels, channels, output_stride=32, num_classes=1000, in_chans=3,
-                 cardinality=1, base_width=64, block=DlaBottle2neck, residual_root=False,
+                 cardinality=1, base_width=64, block=DlaBottle2neck, shortcut_root=False,
                  drop_rate=0.0, global_pool='avg'):
         super(DLA, self).__init__()
         self.channels = channels
@@ -271,7 +271,7 @@ class DLA(nn.Module):
             nn.ReLU(inplace=True))
         self.level0 = self._make_conv_level(channels[0], channels[0], levels[0])
         self.level1 = self._make_conv_level(channels[0], channels[1], levels[1], stride=2)
-        cargs = dict(cardinality=cardinality, base_width=base_width, root_residual=residual_root)
+        cargs = dict(cardinality=cardinality, base_width=base_width, root_shortcut=shortcut_root)
         self.level2 = DlaTree(levels[2], block, channels[1], channels[2], 2, level_root=False, **cargs)
         self.level3 = DlaTree(levels[3], block, channels[2], channels[3], 2, level_root=True, **cargs)
         self.level4 = DlaTree(levels[4], block, channels[3], channels[4], 2, level_root=True, **cargs)
@@ -413,7 +413,7 @@ def dla60x(pretrained=False, **kwargs):  # DLA-X-60
 def dla102(pretrained=False, **kwargs):  # DLA-102
     model_kwargs = dict(
         levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, residual_root=True, **kwargs)
+        block=DlaBottleneck, shortcut_root=True, **kwargs)
     return _create_dla('dla102', pretrained, **model_kwargs)
@@ -421,7 +421,7 @@ def dla102(pretrained=False, **kwargs):  # DLA-102
 def dla102x(pretrained=False, **kwargs):  # DLA-X-102
     model_kwargs = dict(
         levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, cardinality=32, base_width=4, residual_root=True, **kwargs)
+        block=DlaBottleneck, cardinality=32, base_width=4, shortcut_root=True, **kwargs)
     return _create_dla('dla102x', pretrained, **model_kwargs)
@@ -429,7 +429,7 @@ def dla102x(pretrained=False, **kwargs):  # DLA-X-102
 def dla102x2(pretrained=False, **kwargs):  # DLA-X-102 64
     model_kwargs = dict(
         levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, cardinality=64, base_width=4, residual_root=True, **kwargs)
+        block=DlaBottleneck, cardinality=64, base_width=4, shortcut_root=True, **kwargs)
     return _create_dla('dla102x2', pretrained, **model_kwargs)
@@ -437,5 +437,5 @@ def dla102x2(pretrained=False, **kwargs):  # DLA-X-102 64
 def dla169(pretrained=False, **kwargs):  # DLA-169
     model_kwargs = dict(
         levels=[1, 1, 2, 3, 5, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, residual_root=True, **kwargs)
+        block=DlaBottleneck, shortcut_root=True, **kwargs)
     return _create_dla('dla169', pretrained, **model_kwargs)

@@ -184,7 +184,7 @@ class DepthwiseSeparableConv(nn.Module):
         return info

     def forward(self, x):
-        residual = x
+        shortcut = x

         x = self.conv_dw(x)
         x = self.bn1(x)
@@ -200,7 +200,7 @@ class DepthwiseSeparableConv(nn.Module):
         if self.has_residual:
             if self.drop_path_rate > 0.:
                 x = drop_path(x, self.drop_path_rate, self.training)
-            x += residual
+            x += shortcut
         return x
@@ -258,7 +258,7 @@ class InvertedResidual(nn.Module):
         return info

     def forward(self, x):
-        residual = x
+        shortcut = x

         # Point-wise expansion
         x = self.conv_pw(x)
@@ -281,7 +281,7 @@ class InvertedResidual(nn.Module):
         if self.has_residual:
             if self.drop_path_rate > 0.:
                 x = drop_path(x, self.drop_path_rate, self.training)
-            x += residual
+            x += shortcut
         return x
@@ -308,7 +308,7 @@ class CondConvResidual(InvertedResidual):
         self.routing_fn = nn.Linear(in_chs, self.num_experts)

     def forward(self, x):
-        residual = x
+        shortcut = x

         # CondConv routing
         pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1)
@@ -335,7 +335,7 @@ class CondConvResidual(InvertedResidual):
         if self.has_residual:
             if self.drop_path_rate > 0.:
                 x = drop_path(x, self.drop_path_rate, self.training)
-            x += residual
+            x += shortcut
         return x
@@ -390,7 +390,7 @@ class EdgeResidual(nn.Module):
         return info

     def forward(self, x):
-        residual = x
+        shortcut = x

         # Expansion convolution
         x = self.conv_exp(x)
@@ -408,6 +408,6 @@ class EdgeResidual(nn.Module):
         if self.has_residual:
             if self.drop_path_rate > 0.:
                 x = drop_path(x, self.drop_path_rate, self.training)
-            x += residual
+            x += shortcut
         return x

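All four blocks above share the same residual tail: save the input as `shortcut`, optionally thin the transformed branch with stochastic depth, then add. For reference, a minimal sketch of the `drop_path` helper these blocks call, matching timm's stochastic-depth semantics as I understand them (not part of this diff):

```python
import torch

def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Per-sample stochastic depth: zero the residual branch for a random
    subset of samples, rescaling kept samples by 1 / keep_prob."""
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # one Bernoulli draw per sample, broadcast over (C, H, W)
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # 0 or 1 per sample
    return x.div(keep_prob) * random_tensor

# usage pattern shared by the blocks above:
#   shortcut = x
#   x = ...convs/bn/act...
#   if self.has_residual:
#       if self.drop_path_rate > 0.:
#           x = drop_path(x, self.drop_path_rate, self.training)
#       x += shortcut
```
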
@@ -112,7 +112,7 @@ class GhostBottleneck(nn.Module):

     def forward(self, x):
-        residual = x
+        shortcut = x

         # 1st ghost bottleneck
         x = self.ghost1(x)
@@ -129,7 +129,7 @@ class GhostBottleneck(nn.Module):
         # 2nd ghost bottleneck
         x = self.ghost2(x)

-        x += self.shortcut(residual)
+        x += self.shortcut(shortcut)
         return x

@@ -1,7 +1,6 @@
 from .activations import *
 from .adaptive_avgmax_pool import \
     adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
-from .anti_aliasing import AntiAliasDownsampleLayer
 from .blur_pool import BlurPool2d
 from .classifier import ClassifierHead, create_classifier
 from .cond_conv2d import CondConv2d, get_condconv_initializer

@@ -1,60 +0,0 @@
-import torch
-import torch.nn.parallel
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class AntiAliasDownsampleLayer(nn.Module):
-    def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2, no_jit: bool = False):
-        super(AntiAliasDownsampleLayer, self).__init__()
-        if no_jit:
-            self.op = Downsample(channels, filt_size, stride)
-        else:
-            self.op = DownsampleJIT(channels, filt_size, stride)
-
-        # FIXME I should probably override _apply and clear DownsampleJIT filter cache for .cuda(), .half(), etc calls
-
-    def forward(self, x):
-        return self.op(x)
-
-
-@torch.jit.script
-class DownsampleJIT(object):
-    def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2):
-        self.channels = channels
-        self.stride = stride
-        self.filt_size = filt_size
-        assert self.filt_size == 3
-        assert stride == 2
-        self.filt = {}  # lazy init by device for DataParallel compat
-
-    def _create_filter(self, like: torch.Tensor):
-        filt = torch.tensor([1., 2., 1.], dtype=like.dtype, device=like.device)
-        filt = filt[:, None] * filt[None, :]
-        filt = filt / torch.sum(filt)
-        return filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
-
-    def __call__(self, input: torch.Tensor):
-        input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
-        filt = self.filt.get(str(input.device), self._create_filter(input))
-        return F.conv2d(input_pad, filt, stride=2, padding=0, groups=input.shape[1])
-
-
-class Downsample(nn.Module):
-    def __init__(self, channels=None, filt_size=3, stride=2):
-        super(Downsample, self).__init__()
-        self.channels = channels
-        self.filt_size = filt_size
-        self.stride = stride
-        assert self.filt_size == 3
-        filt = torch.tensor([1., 2., 1.])
-        filt = filt[:, None] * filt[None, :]
-        filt = filt / torch.sum(filt)
-        # self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
-        self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))
-
-    def forward(self, input):
-        input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
-        return F.conv2d(input_pad, self.filt, stride=self.stride, padding=0, groups=input.shape[1])

@@ -3,8 +3,6 @@ BlurPool layer inspired by
  - Kornia's Max_BlurPool2d
  - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`

-FIXME merge this impl with those in `anti_aliasing.py`
-
 Hacked together by Chris Ha and Ross Wightman
 """
@@ -12,7 +10,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
-from typing import Dict

 from .padding import get_padding
@@ -29,30 +26,17 @@ class BlurPool2d(nn.Module):
     Returns:
         torch.Tensor: the transformed tensor.
     """
-    filt: Dict[str, torch.Tensor]
-
     def __init__(self, channels, filt_size=3, stride=2) -> None:
         super(BlurPool2d, self).__init__()
         assert filt_size > 1
         self.channels = channels
         self.filt_size = filt_size
         self.stride = stride
-        pad_size = [get_padding(filt_size, stride, dilation=1)] * 4
-        self.padding = nn.ReflectionPad2d(pad_size)
-        self._coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs)  # for torchscript compat
-        self.filt = {}  # lazy init by device for DataParallel compat
-
-    def _create_filter(self, like: torch.Tensor):
-        blur_filter = (self._coeffs[:, None] * self._coeffs[None, :]).to(dtype=like.dtype, device=like.device)
-        return blur_filter[None, None, :, :].repeat(self.channels, 1, 1, 1)
-
-    def _apply(self, fn):
-        # override nn.Module _apply, reset filter cache if used
-        self.filt = {}
-        super(BlurPool2d, self)._apply(fn)
-
-    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
-        C = input_tensor.shape[1]
-        blur_filt = self.filt.get(str(input_tensor.device), self._create_filter(input_tensor))
-        return F.conv2d(
-            self.padding(input_tensor), blur_filt, stride=self.stride, groups=C)
+        self.padding = [get_padding(filt_size, stride, dilation=1)] * 4
+        coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32))
+        blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :].repeat(self.channels, 1, 1, 1)
+        self.register_buffer('filt', blur_filter, persistent=False)

+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = F.pad(x, self.padding, 'reflect')
+        return F.conv2d(x, self.filt, stride=self.stride, groups=x.shape[1])

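The rewrite drops the per-device filter dict (and the `_apply` override that cleared it) in favor of a registered non-persistent buffer, which follows the module through `.to()`/`.cuda()`/`.half()` automatically and stays out of the state dict. A quick usage sketch, assuming the layer is importable as `timm.models.layers.BlurPool2d`:

```python
import torch
from timm.models.layers import BlurPool2d

pool = BlurPool2d(channels=64, filt_size=3, stride=2)
x = torch.randn(2, 64, 56, 56)
print(pool(x).shape)    # torch.Size([2, 64, 28, 28]): blur, then stride-2 depthwise conv

# the filter buffer now tracks dtype/device moves; no cache to invalidate
pool = pool.half()
print(pool.filt.dtype)  # torch.float16
```
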
@@ -21,6 +21,7 @@ import torch.nn as nn
 import torch.nn.functional as F

 from .helpers import to_2tuple
+from .weight_init import trunc_normal_


 def rel_logits_1d(q, rel_k, permute_mask: List[int]):
@@ -101,6 +102,11 @@ class BottleneckAttn(nn.Module):
         self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()

+    def reset_parameters(self):
+        trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)
+        trunc_normal_(self.pos_embed.height_rel, std=self.scale)
+        trunc_normal_(self.pos_embed.width_rel, std=self.scale)
+
     def forward(self, x):
         B, C, H, W = x.shape
         assert H == self.pos_embed.height and W == self.pos_embed.width

@@ -25,6 +25,8 @@ import torch
 from torch import nn
 import torch.nn.functional as F

+from .weight_init import trunc_normal_
+

 def rel_logits_1d(q, rel_k, permute_mask: List[int]):
     """ Compute relative logits along one dimension
@@ -124,6 +126,13 @@ class HaloAttn(nn.Module):
         self.pos_embed = PosEmbedRel(
             block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head, scale=self.scale)

+    def reset_parameters(self):
+        std = self.q.weight.shape[1] ** -0.5  # fan-in
+        trunc_normal_(self.q.weight, std=std)
+        trunc_normal_(self.kv.weight, std=std)
+        trunc_normal_(self.pos_embed.height_rel, std=self.scale)
+        trunc_normal_(self.pos_embed.width_rel, std=self.scale)
+
     def forward(self, x):
         B, C, H, W = x.shape
         assert H % self.block_size == 0 and W % self.block_size == 0

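The new `reset_parameters()` hooks use a truncated normal with std scaled by fan-in; since weight shape is (out, in, ...), `weight.shape[1]` is the input dim for these 1x1 projections. A small illustration, assuming `trunc_normal_` is exported from `timm.models.layers`:

```python
import torch.nn as nn
from timm.models.layers import trunc_normal_

q = nn.Conv2d(256, 128, kernel_size=1, bias=False)
std = q.weight.shape[1] ** -0.5  # fan-in scaling: 256 ** -0.5 = 0.0625
trunc_normal_(q.weight, std=std)
print(float(q.weight.std()))     # ~0.06 (truncation shrinks it slightly)
```
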
@@ -24,6 +24,7 @@ import torch
 from torch import nn
 import torch.nn.functional as F

+from .weight_init import trunc_normal_


 class LambdaLayer(nn.Module):
@@ -36,6 +37,7 @@ class LambdaLayer(nn.Module):
             self,
             dim, dim_out=None, stride=1, num_heads=4, dim_head=16, r=7, qkv_bias=False):
         super().__init__()
+        self.dim = dim
         self.dim_out = dim_out or dim
         self.dim_k = dim_head  # query depth 'k'
         self.num_heads = num_heads
@@ -55,6 +57,10 @@ class LambdaLayer(nn.Module):
         self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()

+    def reset_parameters(self):
+        trunc_normal_(self.qkv.weight, std=self.dim ** -0.5)
+        trunc_normal_(self.conv_lambda.weight, std=self.dim_k ** -0.5)
+
     def forward(self, x):
         B, C, H, W = x.shape
         M = H * W

@@ -107,6 +107,7 @@ class WindowAttention(nn.Module):
         self.relative_position_bias_table = nn.Parameter(
             # 2 * Wh - 1 * 2 * Ww - 1, nH
             torch.zeros((2 * self.win_size - 1) * (2 * self.win_size - 1), num_heads))
+        trunc_normal_(self.relative_position_bias_table, std=.02)

         # get pair-wise relative position index for each token inside the window
         coords_h = torch.arange(self.win_size)
@@ -120,13 +121,16 @@ class WindowAttention(nn.Module):
         relative_coords[:, :, 0] *= 2 * self.win_size - 1
         relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
         self.register_buffer("relative_position_index", relative_position_index)
-        trunc_normal_(self.relative_position_bias_table, std=.02)

         self.qkv = nn.Linear(dim, self.dim_out * 3, bias=qkv_bias)
         self.attn_drop = nn.Dropout(attn_drop)
         self.softmax = nn.Softmax(dim=-1)
         self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()

+    def reset_parameters(self):
+        trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)
+        trunc_normal_(self.relative_position_bias_table, std=.02)
+
     def forward(self, x):
         B, C, H, W = x.shape
         x = x.permute(0, 2, 3, 1)

@@ -0,0 +1,292 @@
+""" MLP-Mixer in PyTorch
+
+Official JAX impl: https://github.com/google-research/vision_transformer/blob/linen/vit_jax/models_mixer.py
+
+Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
+
+@article{tolstikhin2021,
+  title={MLP-Mixer: An all-MLP Architecture for Vision},
+  author={Tolstikhin, Ilya and Houlsby, Neil and Kolesnikov, Alexander and Beyer, Lucas and Zhai, Xiaohua and Unterthiner,
+    Thomas and Yung, Jessica and Keysers, Daniel and Uszkoreit, Jakob and Lucic, Mario and Dosovitskiy, Alexey},
+  journal={arXiv preprint arXiv:2105.01601},
+  year={2021}
+}
+
+A thank you to paper authors for releasing code and weights.
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+import math
+from copy import deepcopy
+from functools import partial
+
+import torch
+import torch.nn as nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import build_model_with_cfg, overlay_external_default_cfg
+from .layers import DropPath, to_2tuple, lecun_normal_
+from .registry import register_model
+
+
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'crop_pct': 0.875, 'interpolation': 'bicubic', 'fixed_input_size': True,
+        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
+        'first_conv': 'stem.proj', 'classifier': 'head',
+        **kwargs
+    }
+
+
+default_cfgs = dict(
+    mixer_s32_224=_cfg(),
+    mixer_s16_224=_cfg(),
+    mixer_b32_224=_cfg(),
+    mixer_b16_224=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224-76587d61.pth',
+    ),
+    mixer_b16_224_in21k=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224_in21k-617b3de2.pth',
+        num_classes=21843
+    ),
+    mixer_l32_224=_cfg(),
+    mixer_l16_224=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224-92f9adc4.pth',
+    ),
+    mixer_l16_224_in21k=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224_in21k-846aa33c.pth',
+        num_classes=21843
+    ),
+)
+
+
+class Mlp(nn.Module):
+    """ MLP Block
+    NOTE: same impl as ViT, move to common location
+    """
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class PatchEmbed(nn.Module):
+    """ Image to Patch Embedding
+    NOTE: same impl as ViT, move to common location
+    """
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.patch_grid = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+        self.num_patches = self.patch_grid[0] * self.patch_grid[1]
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        # FIXME look at relaxing size constraints
+        assert H == self.img_size[0] and W == self.img_size[1], \
+            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        x = self.proj(x).flatten(2).transpose(1, 2)
+        x = self.norm(x)
+        return x
+
+
+class MixerBlock(nn.Module):
+    def __init__(
+            self, dim, seq_len, tokens_dim, channels_dim,
+            norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.mlp_tokens = Mlp(seq_len, tokens_dim, act_layer=act_layer, drop=drop)
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        self.mlp_channels = Mlp(dim, channels_dim, act_layer=act_layer, drop=drop)
+
+    def forward(self, x):
+        x = x + self.drop_path(self.mlp_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2))
+        x = x + self.drop_path(self.mlp_channels(self.norm2(x)))
+        return x
+
+
+class MlpMixer(nn.Module):
+    def __init__(
+            self,
+            num_classes=1000,
+            img_size=224,
+            in_chans=3,
+            patch_size=16,
+            num_blocks=8,
+            hidden_dim=512,
+            tokens_dim=256,
+            channels_dim=2048,
+            norm_layer=partial(nn.LayerNorm, eps=1e-6),
+            act_layer=nn.GELU,
+            drop=0.,
+            drop_path=0.,
+            nlhb=False,
+    ):
+        super().__init__()
+        self.num_classes = num_classes
+
+        self.stem = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=hidden_dim)
+        # FIXME drop_path (stochastic depth scaling rule?)
+        self.blocks = nn.Sequential(*[
+            MixerBlock(
+                hidden_dim, self.stem.num_patches, tokens_dim, channels_dim,
+                norm_layer=norm_layer, act_layer=act_layer, drop=drop, drop_path=drop_path)
+            for _ in range(num_blocks)])
+        self.norm = norm_layer(hidden_dim)
+        self.head = nn.Linear(hidden_dim, self.num_classes)  # zero init
+
+        self.init_weights(nlhb=nlhb)
+
+    def init_weights(self, nlhb=False):
+        head_bias = -math.log(self.num_classes) if nlhb else 0.
+        for n, m in self.named_modules():
+            _init_weights(m, n, head_bias=head_bias)
+
+    def forward(self, x):
+        x = self.stem(x)
+        x = self.blocks(x)
+        x = self.norm(x)
+        x = x.mean(dim=1)
+        x = self.head(x)
+        return x
+
+
+def _init_weights(m, n: str, head_bias: float = 0.):
+    """ Mixer weight initialization (trying to match Flax defaults)
+    """
+    if isinstance(m, nn.Linear):
+        if n.startswith('head'):
+            nn.init.zeros_(m.weight)
+            nn.init.constant_(m.bias, head_bias)
+        else:
+            nn.init.xavier_uniform_(m.weight)
+            if m.bias is not None:
+                if 'mlp' in n:
+                    nn.init.normal_(m.bias, std=1e-6)
+                else:
+                    nn.init.zeros_(m.bias)
+    elif isinstance(m, nn.Conv2d):
+        lecun_normal_(m.weight)
+        if m.bias is not None:
+            nn.init.zeros_(m.bias)
+    elif isinstance(m, nn.LayerNorm):
+        nn.init.zeros_(m.bias)
+        nn.init.ones_(m.weight)
+
+
+def _create_mixer(variant, pretrained=False, default_cfg=None, **kwargs):
+    if default_cfg is None:
+        default_cfg = deepcopy(default_cfgs[variant])
+    overlay_external_default_cfg(default_cfg, kwargs)
+    default_num_classes = default_cfg['num_classes']
+    default_img_size = default_cfg['input_size'][-2:]
+
+    num_classes = kwargs.pop('num_classes', default_num_classes)
+    img_size = kwargs.pop('img_size', default_img_size)
+    if kwargs.get('features_only', None):
+        raise RuntimeError('features_only not implemented for MLP-Mixer models.')
+
+    model = build_model_with_cfg(
+        MlpMixer, variant, pretrained,
+        default_cfg=default_cfg,
+        img_size=img_size,
+        num_classes=num_classes,
+        **kwargs)
+
+    return model
+
+
+@register_model
+def mixer_s32_224(pretrained=False, **kwargs):
+    """ Mixer-S/32 224x224
+    """
+    model_args = dict(patch_size=32, num_blocks=8, hidden_dim=512, tokens_dim=256, channels_dim=2048, **kwargs)
+    model = _create_mixer('mixer_s32_224', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def mixer_s16_224(pretrained=False, **kwargs):
+    """ Mixer-S/16 224x224
+    """
+    model_args = dict(patch_size=16, num_blocks=8, hidden_dim=512, tokens_dim=256, channels_dim=2048, **kwargs)
+    model = _create_mixer('mixer_s16_224', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def mixer_b32_224(pretrained=False, **kwargs):
+    """ Mixer-B/32 224x224
+    """
+    model_args = dict(patch_size=32, num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072, **kwargs)
+    model = _create_mixer('mixer_b32_224', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def mixer_b16_224(pretrained=False, **kwargs):
+    """ Mixer-B/16 224x224. ImageNet-1k pretrained weights.
+    """
+    model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072, **kwargs)
+    model = _create_mixer('mixer_b16_224', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def mixer_b16_224_in21k(pretrained=False, **kwargs):
+    """ Mixer-B/16 224x224. ImageNet-21k pretrained weights.
+    """
+    model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072, **kwargs)
+    model = _create_mixer('mixer_b16_224_in21k', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def mixer_l32_224(pretrained=False, **kwargs):
+    """ Mixer-L/32 224x224.
+    """
+    model_args = dict(patch_size=32, num_blocks=24, hidden_dim=1024, tokens_dim=512, channels_dim=4096, **kwargs)
+    model = _create_mixer('mixer_l32_224', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def mixer_l16_224(pretrained=False, **kwargs):
+    """ Mixer-L/16 224x224. ImageNet-1k pretrained weights.
+    """
+    model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, tokens_dim=512, channels_dim=4096, **kwargs)
+    model = _create_mixer('mixer_l16_224', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def mixer_l16_224_in21k(pretrained=False, **kwargs):
+    """ Mixer-L/16 224x224. ImageNet-21k pretrained weights.
+    """
+    model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, tokens_dim=512, channels_dim=4096, **kwargs)
+    model = _create_mixer('mixer_l16_224_in21k', pretrained=pretrained, **model_args)
+    return model

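A quick smoke test for the new model family once this file is registered (pretrained weights exist only for the variants with URLs in default_cfgs above):

```python
import torch
import timm

model = timm.create_model('mixer_b16_224', pretrained=False)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])
# B/16 at 224: (224 // 16) ** 2 = 196 tokens; each MixerBlock transposes to
# mix across the 196 tokens, then back to mix across the 768 channels
```
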
@@ -91,7 +91,7 @@ class Bottle2neck(nn.Module):
             nn.init.zeros_(self.bn3.weight)

     def forward(self, x):
-        residual = x
+        shortcut = x

         out = self.conv1(x)
         out = self.bn1(out)
@@ -124,9 +124,9 @@ class Bottle2neck(nn.Module):
             out = self.se(out)

         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)

-        out += residual
+        out += shortcut
         out = self.relu(out)

         return out

@@ -105,7 +105,7 @@ class ResNestBottleneck(nn.Module):
             nn.init.zeros_(self.bn3.weight)

     def forward(self, x):
-        residual = x
+        shortcut = x

         out = self.conv1(x)
         out = self.bn1(out)
@@ -132,9 +132,9 @@ class ResNestBottleneck(nn.Module):
         out = self.drop_block(out)

         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)

-        out += residual
+        out += shortcut
         out = self.act3(out)

         return out

@@ -241,31 +241,31 @@ default_cfgs = {
     # ResNet-RS models
     'resnetrs50': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs50-7c9728e2.pth',
-        input_size=(3, 160, 160), pool_size=(4, 4), crop_pct=0.91, test_input_size=(3, 224, 224),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs50_ema-6b53758b.pth',
+        input_size=(3, 160, 160), pool_size=(5, 5), crop_pct=0.91, test_input_size=(3, 224, 224),
         interpolation='bicubic', first_conv='conv1.0'),
     'resnetrs101': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs101-3e4bb55c.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs101_i192_ema-1509bbf6.pth',
         input_size=(3, 192, 192), pool_size=(6, 6), crop_pct=0.94, test_input_size=(3, 288, 288),
         interpolation='bicubic', first_conv='conv1.0'),
     'resnetrs152': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs152-b1efe56d.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs152_i256_ema-a9aff7f9.pth',
         input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320),
         interpolation='bicubic', first_conv='conv1.0'),
     'resnetrs200': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs200-b455b791.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs200_ema-623d2f59.pth',
         input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320),
         interpolation='bicubic', first_conv='conv1.0'),
     'resnetrs270': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs270-cafcfbc7.pth',
-        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs270_ema-b40e674c.pth',
+        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 352, 352),
         interpolation='bicubic', first_conv='conv1.0'),
     'resnetrs350': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs350-06d9bfac.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs350_i256_ema-5a1aa8f1.pth',
         input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0, test_input_size=(3, 384, 384),
         interpolation='bicubic', first_conv='conv1.0'),
     'resnetrs420': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs420-d26764a5.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs420_ema-972dee69.pth',
         input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, test_input_size=(3, 416, 416),
         interpolation='bicubic', first_conv='conv1.0'),
 }
@@ -315,7 +315,7 @@ class BasicBlock(nn.Module):
             nn.init.zeros_(self.bn2.weight)

     def forward(self, x):
-        residual = x
+        shortcut = x

         x = self.conv1(x)
         x = self.bn1(x)
@@ -337,8 +337,8 @@ class BasicBlock(nn.Module):
             x = self.drop_path(x)

         if self.downsample is not None:
-            residual = self.downsample(residual)
-        x += residual
+            shortcut = self.downsample(shortcut)
+        x += shortcut
         x = self.act2(x)

         return x
@@ -385,7 +385,7 @@ class Bottleneck(nn.Module):
             nn.init.zeros_(self.bn3.weight)

     def forward(self, x):
-        residual = x
+        shortcut = x

         x = self.conv1(x)
         x = self.bn1(x)
@@ -413,8 +413,8 @@ class Bottleneck(nn.Module):
             x = self.drop_path(x)

         if self.downsample is not None:
-            residual = self.downsample(residual)
-        x += residual
+            shortcut = self.downsample(shortcut)
+        x += shortcut
         x = self.act3(x)

         return x

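These cfg fields drive preprocessing: the RS models train at the smaller input_size but evaluate better at test_input_size (e.g. resnetrs270's test size moves from 320 to 352 here). A sketch of reading the resolved config; `resolve_data_config` is from `timm.data` and, as I understand it, pulls defaults from the model's default_cfg:

```python
import timm
from timm.data import resolve_data_config

model = timm.create_model('resnetrs270', pretrained=False)
config = resolve_data_config({}, model=model)
print(config['input_size'], config['crop_pct'])  # (3, 256, 256) 1.0
```
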
@@ -92,7 +92,7 @@ class Bottleneck(nn.Module):
     """

     def forward(self, x):
-        residual = x
+        shortcut = x

         out = self.conv1(x)
         out = self.bn1(out)
@@ -106,9 +106,9 @@ class Bottleneck(nn.Module):
         out = self.bn3(out)

         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)

-        out = self.se_module(out) + residual
+        out = self.se_module(out) + shortcut
         out = self.relu(out)

         return out
@@ -204,7 +204,7 @@ class SEResNetBlock(nn.Module):
         self.stride = stride

     def forward(self, x):
-        residual = x
+        shortcut = x

         out = self.conv1(x)
         out = self.bn1(out)
@@ -215,9 +215,9 @@ class SEResNetBlock(nn.Module):
         out = self.relu(out)

         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)

-        out = self.se_module(out) + residual
+        out = self.se_module(out) + shortcut
         out = self.relu(out)

         return out

@@ -76,7 +76,7 @@ class SelectiveKernelBasic(nn.Module):
             nn.init.zeros_(self.conv2.bn.weight)

     def forward(self, x):
-        residual = x
+        shortcut = x
         x = self.conv1(x)
         x = self.conv2(x)
         if self.se is not None:
@@ -84,8 +84,8 @@ class SelectiveKernelBasic(nn.Module):
         if self.drop_path is not None:
             x = self.drop_path(x)
         if self.downsample is not None:
-            residual = self.downsample(residual)
-        x += residual
+            shortcut = self.downsample(shortcut)
+        x += shortcut
         x = self.act(x)
         return x
@@ -124,7 +124,7 @@ class SelectiveKernelBottleneck(nn.Module):
             nn.init.zeros_(self.conv3.bn.weight)

     def forward(self, x):
-        residual = x
+        shortcut = x
         x = self.conv1(x)
         x = self.conv2(x)
         x = self.conv3(x)
@@ -133,8 +133,8 @@ class SelectiveKernelBottleneck(nn.Module):
         if self.drop_path is not None:
             x = self.drop_path(x)
         if self.downsample is not None:
-            residual = self.downsample(residual)
-        x += residual
+            shortcut = self.downsample(shortcut)
+        x += shortcut
         x = self.act(x)
         return x

@@ -5,16 +5,13 @@ https://arxiv.org/pdf/2003.13630.pdf
 Original model: https://github.com/mrT23/TResNet

 """
-import copy
 from collections import OrderedDict
-from functools import partial

 import torch
 import torch.nn as nn
-import torch.nn.functional as F

 from .helpers import build_model_with_cfg
-from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, InplaceAbn, ClassifierHead, SEModule
+from .layers import SpaceToDepthModule, BlurPool2d, InplaceAbn, ClassifierHead, SEModule
 from .registry import register_model

 __all__ = ['tresnet_m', 'tresnet_l', 'tresnet_xl']
@@ -92,9 +89,9 @@ class BasicBlock(nn.Module):

     def forward(self, x):
         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)
         else:
-            residual = x
+            shortcut = x

         out = self.conv1(x)
         out = self.conv2(out)
@@ -102,7 +99,7 @@ class BasicBlock(nn.Module):
         if self.se is not None:
             out = self.se(out)

-        out += residual
+        out += shortcut
         out = self.relu(out)

         return out
@@ -139,9 +136,9 @@ class Bottleneck(nn.Module):

     def forward(self, x):
         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)
         else:
-            residual = x
+            shortcut = x

         out = self.conv1(x)
         out = self.conv2(out)
@@ -149,22 +146,19 @@ class Bottleneck(nn.Module):
             out = self.se(out)

         out = self.conv3(out)
-        out = out + residual  # no inplace
+        out = out + shortcut  # no inplace
         out = self.relu(out)

         return out


 class TResNet(nn.Module):
-    def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, no_aa_jit=False,
-                 global_pool='fast', drop_rate=0.):
+    def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, global_pool='fast', drop_rate=0.):
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         super(TResNet, self).__init__()

-        # JIT layers
-        space_to_depth = SpaceToDepthModule()
-        aa_layer = partial(AntiAliasDownsampleLayer, no_jit=no_aa_jit)
+        aa_layer = BlurPool2d

         # TResnet stages
         self.inplanes = int(64 * width_factor)
@@ -181,7 +175,7 @@ class TResNet(nn.Module):
         # body
         self.body = nn.Sequential(OrderedDict([
-            ('SpaceToDepth', space_to_depth),
+            ('SpaceToDepth', SpaceToDepthModule()),
             ('conv1', conv1),
             ('layer1', layer1),
             ('layer2', layer2),

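The aa_layer swap works because BlurPool2d accepts the same core arguments the deleted AntiAliasDownsampleLayer took, so the plain class can be passed where the old code needed a `partial()` binding the now-removed no_aa_jit flag. A sketch of the call-site shape, with hypothetical channel values (exact call sites live in the block/conv helpers, not shown in this diff):

```python
from timm.models.layers import BlurPool2d

aa_layer = BlurPool2d                              # as in the new TResNet.__init__
aa = aa_layer(channels=64, filt_size=3, stride=2)  # what downstream blocks construct
```
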