From 0be1fa4793aed6d926dff43e59491588c744bd89 Mon Sep 17 00:00:00 2001
From: Michael Monashev
Date: Sun, 11 Apr 2021 18:08:43 +0300
Subject: [PATCH 1/8] Argument description fixed

---
 benchmark.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/benchmark.py b/benchmark.py
index 4812d85c..dc76550a 100755
--- a/benchmark.py
+++ b/benchmark.py
@@ -48,7 +48,7 @@ parser = argparse.ArgumentParser(description='PyTorch Benchmark')
 parser.add_argument('--model-list', metavar='NAME', default='',
                     help='txt file based list of model names to benchmark')
 parser.add_argument('--bench', default='both', type=str,
-                    help="Benchmark mode. One of 'inference', 'train', 'both'. Defaults to 'inference'")
+                    help="Benchmark mode. One of 'inference', 'train', 'both'. Defaults to 'both'")
 parser.add_argument('--detail', action='store_true', default=False,
                     help='Provide train fwd/bwd/opt breakdown detail if True. Defaults to False')
 parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',

From 08d60f4a9a385b1eaf132a0ac1cff37b4ad94ac8 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 4 May 2021 12:41:09 -0700
Subject: [PATCH 2/8] resnetrs50 pool sizing wrong

---
 timm/models/resnet.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/timm/models/resnet.py b/timm/models/resnet.py
index 5355d61d..cccc9ae0 100644
--- a/timm/models/resnet.py
+++ b/timm/models/resnet.py
@@ -242,7 +242,7 @@ default_cfgs = {
     # ResNet-RS models
     'resnetrs50': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs50-7c9728e2.pth',
-        input_size=(3, 160, 160), pool_size=(4, 4), crop_pct=0.91, test_input_size=(3, 224, 224),
+        input_size=(3, 160, 160), pool_size=(5, 5), crop_pct=0.91, test_input_size=(3, 224, 224),
         interpolation='bicubic', first_conv='conv1.0'),
     'resnetrs101': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs101-3e4bb55c.pth',
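The pool_size fix above follows from ResNet's total downsampling factor of 32: a 160x160 train image leaves a 5x5 final feature map, not 4x4. A quick sketch of that arithmetic (plain Python, not a timm API; the stride-32 reduction is the standard ResNet assumption):

    # Final feature-map (and thus default pool) size = input edge // total stride.
    def expected_pool_size(input_size, total_stride=32):
        _, h, w = input_size
        return (h // total_stride, w // total_stride)

    assert expected_pool_size((3, 160, 160)) == (5, 5)    # resnetrs50, as corrected
    assert expected_pool_size((3, 192, 192)) == (6, 6)    # matches resnetrs101
    assert expected_pool_size((3, 320, 320)) == (10, 10)  # matches resnetrs420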
From ddc743fdf890bc79a48dfca548ec2286438b93f2 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 4 May 2021 16:16:55 -0700
Subject: [PATCH 3/8] Update ResNet-RS models to EMA weights

---
 timm/models/resnet.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/timm/models/resnet.py b/timm/models/resnet.py
index cccc9ae0..7fd47057 100644
--- a/timm/models/resnet.py
+++ b/timm/models/resnet.py
@@ -241,31 +241,31 @@ default_cfgs = {

     # ResNet-RS models
     'resnetrs50': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs50-7c9728e2.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs50_ema-6b53758b.pth',
         input_size=(3, 160, 160), pool_size=(5, 5), crop_pct=0.91, test_input_size=(3, 224, 224),
         interpolation='bicubic', first_conv='conv1.0'),
     'resnetrs101': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs101-3e4bb55c.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs101_i192_ema-1509bbf6.pth',
         input_size=(3, 192, 192), pool_size=(6, 6), crop_pct=0.94, test_input_size=(3, 288, 288),
         interpolation='bicubic', first_conv='conv1.0'),
     'resnetrs152': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs152-b1efe56d.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs152_i256_ema-a9aff7f9.pth',
         input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320),
         interpolation='bicubic', first_conv='conv1.0'),
     'resnetrs200': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs200-b455b791.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs200_ema-623d2f59.pth',
         input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320),
         interpolation='bicubic', first_conv='conv1.0'),
     'resnetrs270': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs270-cafcfbc7.pth',
-        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs270_ema-b40e674c.pth',
+        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 352, 352),
         interpolation='bicubic', first_conv='conv1.0'),
     'resnetrs350': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs350-06d9bfac.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs350_i256_ema-5a1aa8f1.pth',
         input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0, test_input_size=(3, 384, 384),
         interpolation='bicubic', first_conv='conv1.0'),
     'resnetrs420': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs420-d26764a5.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs420_ema-972dee69.pth',
         input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, test_input_size=(3, 416, 416),
         interpolation='bicubic', first_conv='conv1.0'),
 }
From 0d87650fea5ce607d07204806e8143a48917d96f Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 4 May 2021 16:56:28 -0700
Subject: [PATCH 4/8] Remove filter hack from BlurPool w/ non-persistent
 buffer. Use BlurPool2d instead of AntiAliasing.. for TResNet. Breaks
 PyTorch < 1.6.

---
 timm/models/layers/__init__.py      |  1 -
 timm/models/layers/anti_aliasing.py | 60 -----------------------------
 timm/models/layers/blur_pool.py     | 32 ++++-----------
 timm/models/tresnet.py              | 14 ++-----
 4 files changed, 12 insertions(+), 95 deletions(-)
 delete mode 100644 timm/models/layers/anti_aliasing.py

diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py
index ac0b6b41..eecbbde4 100644
--- a/timm/models/layers/__init__.py
+++ b/timm/models/layers/__init__.py
@@ -1,7 +1,6 @@
 from .activations import *
 from .adaptive_avgmax_pool import \
     adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
-from .anti_aliasing import AntiAliasDownsampleLayer
 from .blur_pool import BlurPool2d
 from .classifier import ClassifierHead, create_classifier
 from .cond_conv2d import CondConv2d, get_condconv_initializer
diff --git a/timm/models/layers/anti_aliasing.py b/timm/models/layers/anti_aliasing.py
deleted file mode 100644
index 9d3837e8..00000000
--- a/timm/models/layers/anti_aliasing.py
+++ /dev/null
@@ -1,60 +0,0 @@
-import torch
-import torch.nn.parallel
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class AntiAliasDownsampleLayer(nn.Module):
-    def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2, no_jit: bool = False):
-        super(AntiAliasDownsampleLayer, self).__init__()
-        if no_jit:
-            self.op = Downsample(channels, filt_size, stride)
-        else:
-            self.op = DownsampleJIT(channels, filt_size, stride)
-
-        # FIXME I should probably override _apply and clear DownsampleJIT filter cache for .cuda(), .half(), etc calls
-
-    def forward(self, x):
-        return self.op(x)
-
-
-@torch.jit.script
-class DownsampleJIT(object):
-    def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2):
-        self.channels = channels
-        self.stride = stride
-        self.filt_size = filt_size
-        assert self.filt_size == 3
-        assert stride == 2
-        self.filt = {}  # lazy init by device for DataParallel compat
-
-    def _create_filter(self, like: torch.Tensor):
-        filt = torch.tensor([1., 2., 1.], dtype=like.dtype, device=like.device)
-        filt = filt[:, None] * filt[None, :]
-        filt = filt / torch.sum(filt)
-        return filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
-
-    def __call__(self, input: torch.Tensor):
-        input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
-        filt = self.filt.get(str(input.device), self._create_filter(input))
-        return F.conv2d(input_pad, filt, stride=2, padding=0, groups=input.shape[1])
-
-
-class Downsample(nn.Module):
-    def __init__(self, channels=None, filt_size=3, stride=2):
-        super(Downsample, self).__init__()
-        self.channels = channels
-        self.filt_size = filt_size
-        self.stride = stride
-
-        assert self.filt_size == 3
-        filt = torch.tensor([1., 2., 1.])
-        filt = filt[:, None] * filt[None, :]
-        filt = filt / torch.sum(filt)
-
-        # self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
-        self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))
-
-    def forward(self, input):
-        input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
-        return F.conv2d(input_pad, self.filt, stride=self.stride, padding=0, groups=input.shape[1])
diff --git a/timm/models/layers/blur_pool.py b/timm/models/layers/blur_pool.py
index 399cbe35..ca4ce756 100644
--- a/timm/models/layers/blur_pool.py
+++ b/timm/models/layers/blur_pool.py
@@ -3,8 +3,6 @@ BlurPool layer inspired by
  - Kornia's Max_BlurPool2d
  - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`

-FIXME merge this impl with those in `anti_aliasing.py`
-
 Hacked together by Chris Ha and Ross Wightman
 """

@@ -12,7 +10,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
-from typing import Dict

 from .padding import get_padding

@@ -29,30 +26,17 @@ class BlurPool2d(nn.Module):
     Returns:
         torch.Tensor: the transformed tensor.
     """
-    filt: Dict[str, torch.Tensor]
-
     def __init__(self, channels, filt_size=3, stride=2) -> None:
         super(BlurPool2d, self).__init__()
         assert filt_size > 1
         self.channels = channels
         self.filt_size = filt_size
         self.stride = stride
-        pad_size = [get_padding(filt_size, stride, dilation=1)] * 4
-        self.padding = nn.ReflectionPad2d(pad_size)
-        self._coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs)  # for torchscript compat
-        self.filt = {}  # lazy init by device for DataParallel compat
-
-    def _create_filter(self, like: torch.Tensor):
-        blur_filter = (self._coeffs[:, None] * self._coeffs[None, :]).to(dtype=like.dtype, device=like.device)
-        return blur_filter[None, None, :, :].repeat(self.channels, 1, 1, 1)
-
-    def _apply(self, fn):
-        # override nn.Module _apply, reset filter cache if used
-        self.filt = {}
-        super(BlurPool2d, self)._apply(fn)
-
-    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
-        C = input_tensor.shape[1]
-        blur_filt = self.filt.get(str(input_tensor.device), self._create_filter(input_tensor))
-        return F.conv2d(
-            self.padding(input_tensor), blur_filt, stride=self.stride, groups=C)
+        self.padding = [get_padding(filt_size, stride, dilation=1)] * 4
+        coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32))
+        blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :].repeat(self.channels, 1, 1, 1)
+        self.register_buffer('filt', blur_filter, persistent=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = F.pad(x, self.padding, 'reflect')
+        return F.conv2d(x, self.filt, stride=self.stride, groups=x.shape[1])
diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py
index ee1f3fc1..cec51cf4 100644
--- a/timm/models/tresnet.py
+++ b/timm/models/tresnet.py
@@ -5,16 +5,13 @@ https://arxiv.org/pdf/2003.13630.pdf
 Original model: https://github.com/mrT23/TResNet

 """
-import copy
 from collections import OrderedDict
-from functools import partial

 import torch
 import torch.nn as nn
-import torch.nn.functional as F

 from .helpers import build_model_with_cfg
-from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, InplaceAbn, ClassifierHead, SEModule
+from .layers import SpaceToDepthModule, BlurPool2d, InplaceAbn, ClassifierHead, SEModule
 from .registry import register_model

 __all__ = ['tresnet_m', 'tresnet_l', 'tresnet_xl']
@@ -156,15 +153,12 @@ class Bottleneck(nn.Module):

 class TResNet(nn.Module):
-    def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, no_aa_jit=False,
-                 global_pool='fast', drop_rate=0.):
+    def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, global_pool='fast', drop_rate=0.):
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         super(TResNet, self).__init__()

-        # JIT layers
-        space_to_depth = SpaceToDepthModule()
-        aa_layer = partial(AntiAliasDownsampleLayer, no_jit=no_aa_jit)
+        aa_layer = BlurPool2d

         # TResnet stages
         self.inplanes = int(64 * width_factor)
@@ -181,7 +175,7 @@ class TResNet(nn.Module):

         # body
         self.body = nn.Sequential(OrderedDict([
-            ('SpaceToDepth', space_to_depth),
+            ('SpaceToDepth', SpaceToDepthModule()),
             ('conv1', conv1),
             ('layer1', layer1),
             ('layer2', layer2),
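The key mechanism in the BlurPool change is register_buffer(..., persistent=False): the filter still follows the module through .to(), .cuda(), .half() (removing the need for the per-device dict cache and the _apply() override), but is excluded from state_dict(), so existing checkpoints load unchanged. The persistent flag only exists from PyTorch 1.6 on, hence the compatibility note in the subject. A minimal standalone sketch of the behavior (illustrative, not timm code):

    import torch
    import torch.nn as nn

    class Demo(nn.Module):
        def __init__(self):
            super().__init__()
            # Tracked like a buffer, but never serialized.
            self.register_buffer('filt', torch.tensor([1., 2., 1.]), persistent=False)

    m = Demo()
    assert 'filt' not in m.state_dict()   # old checkpoints stay compatible
    m = m.to(torch.float64)
    assert m.filt.dtype == torch.float64  # still follows module-wide casts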
From d5473c17f77d608ee150ef09b0a7c8d590f77aee Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 4 May 2021 21:27:15 -0700
Subject: [PATCH 5/8] Fix incorrect name of shortcut/identity paths in many
 residual nets. Inherited from naming in old old torchvision, long fixed
 there.

---
 timm/models/dla.py                 | 54 +++++++++++++++---------------
 timm/models/efficientnet_blocks.py | 16 ++++-----
 timm/models/ghostnet.py            |  4 +--
 timm/models/res2net.py             |  6 ++--
 timm/models/resnest.py             |  6 ++--
 timm/models/resnet.py              | 12 +++----
 timm/models/senet.py               | 12 +++----
 timm/models/sknet.py               | 12 +++----
 timm/models/tresnet.py             | 12 +++----
 9 files changed, 67 insertions(+), 67 deletions(-)

diff --git a/timm/models/dla.py b/timm/models/dla.py
index 64ad61d6..f0f25b0b 100644
--- a/timm/models/dla.py
+++ b/timm/models/dla.py
@@ -62,9 +62,9 @@ class DlaBasic(nn.Module):
         self.bn2 = nn.BatchNorm2d(planes)
         self.stride = stride

-    def forward(self, x, residual=None):
-        if residual is None:
-            residual = x
+    def forward(self, x, shortcut=None):
+        if shortcut is None:
+            shortcut = x

         out = self.conv1(x)
         out = self.bn1(out)
@@ -73,7 +73,7 @@ class DlaBasic(nn.Module):
         out = self.conv2(out)
         out = self.bn2(out)

-        out += residual
+        out += shortcut
         out = self.relu(out)

         return out
@@ -99,9 +99,9 @@ class DlaBottleneck(nn.Module):
         self.bn3 = nn.BatchNorm2d(outplanes)
         self.relu = nn.ReLU(inplace=True)

-    def forward(self, x, residual=None):
-        if residual is None:
-            residual = x
+    def forward(self, x, shortcut=None):
+        if shortcut is None:
+            shortcut = x

         out = self.conv1(x)
         out = self.bn1(out)
@@ -114,7 +114,7 @@ class DlaBottleneck(nn.Module):
         out = self.conv3(out)
         out = self.bn3(out)

-        out += residual
+        out += shortcut
         out = self.relu(out)

         return out
@@ -154,9 +154,9 @@ class DlaBottle2neck(nn.Module):
         self.bn3 = nn.BatchNorm2d(outplanes)
         self.relu = nn.ReLU(inplace=True)

-    def forward(self, x, residual=None):
-        if residual is None:
-            residual = x
+    def forward(self, x, shortcut=None):
+        if shortcut is None:
+            shortcut = x

         out = self.conv1(x)
         out = self.bn1(out)
@@ -177,26 +177,26 @@ class DlaBottle2neck(nn.Module):
         out = self.conv3(out)
         out = self.bn3(out)

-        out += residual
+        out += shortcut
         out = self.relu(out)

         return out


 class DlaRoot(nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size, residual):
+    def __init__(self, in_channels, out_channels, kernel_size, shortcut):
         super(DlaRoot, self).__init__()
         self.conv = nn.Conv2d(
             in_channels, out_channels, 1, stride=1, bias=False, padding=(kernel_size - 1) // 2)
         self.bn = nn.BatchNorm2d(out_channels)
         self.relu = nn.ReLU(inplace=True)
-        self.residual = residual
+        self.shortcut = shortcut

     def forward(self, *x):
         children = x
         x = self.conv(torch.cat(x, 1))
         x = self.bn(x)
-        if self.residual:
+        if self.shortcut:
             x += children[0]
         x = self.relu(x)

@@ -206,7 +206,7 @@ class DlaRoot(nn.Module):
 class DlaTree(nn.Module):
     def __init__(self, levels, block, in_channels, out_channels, stride=1, dilation=1, cardinality=1, base_width=64,
-                 level_root=False, root_dim=0, root_kernel_size=1, root_residual=False):
+                 level_root=False, root_dim=0, root_kernel_size=1, root_shortcut=False):
         super(DlaTree, self).__init__()
         if root_dim == 0:
             root_dim = 2 * out_channels
@@ -226,24 +226,24 @@ class DlaTree(nn.Module):
                 nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                 nn.BatchNorm2d(out_channels))
         else:
-            cargs.update(dict(root_kernel_size=root_kernel_size, root_residual=root_residual))
+            cargs.update(dict(root_kernel_size=root_kernel_size, root_shortcut=root_shortcut))
             self.tree1 = DlaTree(
                 levels - 1, block, in_channels, out_channels, stride, root_dim=0, **cargs)
             self.tree2 = DlaTree(
                 levels - 1, block, out_channels, out_channels, root_dim=root_dim + out_channels, **cargs)
         if levels == 1:
-            self.root = DlaRoot(root_dim, out_channels, root_kernel_size, root_residual)
+            self.root = DlaRoot(root_dim, out_channels, root_kernel_size, root_shortcut)
         self.level_root = level_root
         self.root_dim = root_dim
         self.levels = levels

-    def forward(self, x, residual=None, children=None):
+    def forward(self, x, shortcut=None, children=None):
         children = [] if children is None else children
         bottom = self.downsample(x)
-        residual = self.project(bottom)
+        shortcut = self.project(bottom)
         if self.level_root:
             children.append(bottom)
-        x1 = self.tree1(x, residual)
+        x1 = self.tree1(x, shortcut)
         if self.levels == 1:
             x2 = self.tree2(x1)
             x = self.root(x2, x1, *children)
@@ -255,7 +255,7 @@ class DlaTree(nn.Module):

 class DLA(nn.Module):
     def __init__(self, levels, channels, output_stride=32, num_classes=1000, in_chans=3,
-                 cardinality=1, base_width=64, block=DlaBottle2neck, residual_root=False,
+                 cardinality=1, base_width=64, block=DlaBottle2neck, shortcut_root=False,
                  drop_rate=0.0, global_pool='avg'):
         super(DLA, self).__init__()
         self.channels = channels
@@ -271,7 +271,7 @@ class DLA(nn.Module):
             nn.ReLU(inplace=True))
         self.level0 = self._make_conv_level(channels[0], channels[0], levels[0])
         self.level1 = self._make_conv_level(channels[0], channels[1], levels[1], stride=2)
-        cargs = dict(cardinality=cardinality, base_width=base_width, root_residual=residual_root)
+        cargs = dict(cardinality=cardinality, base_width=base_width, root_shortcut=shortcut_root)
         self.level2 = DlaTree(levels[2], block, channels[1], channels[2], 2, level_root=False, **cargs)
         self.level3 = DlaTree(levels[3], block, channels[2], channels[3], 2, level_root=True, **cargs)
         self.level4 = DlaTree(levels[4], block, channels[3], channels[4], 2, level_root=True, **cargs)
@@ -413,7 +413,7 @@ def dla60x(pretrained=False, **kwargs):  # DLA-X-60
 def dla102(pretrained=False, **kwargs):  # DLA-102
     model_kwargs = dict(
         levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, residual_root=True, **kwargs)
+        block=DlaBottleneck, shortcut_root=True, **kwargs)
     return _create_dla('dla102', pretrained, **model_kwargs)


@@ -421,7 +421,7 @@ def dla102(pretrained=False, **kwargs):  # DLA-102
 def dla102x(pretrained=False, **kwargs):  # DLA-X-102
     model_kwargs = dict(
         levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, cardinality=32, base_width=4, residual_root=True, **kwargs)
+        block=DlaBottleneck, cardinality=32, base_width=4, shortcut_root=True, **kwargs)
     return _create_dla('dla102x', pretrained, **model_kwargs)


@@ -429,7 +429,7 @@ def dla102x(pretrained=False, **kwargs):  # DLA-X-102
 def dla102x2(pretrained=False, **kwargs):  # DLA-X-102 64
     model_kwargs = dict(
         levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, cardinality=64, base_width=4, residual_root=True, **kwargs)
+        block=DlaBottleneck, cardinality=64, base_width=4, shortcut_root=True, **kwargs)
     return _create_dla('dla102x2', pretrained, **model_kwargs)


@@ -437,5 +437,5 @@ def dla102x2(pretrained=False, **kwargs):  # DLA-X-102 64
 def dla169(pretrained=False, **kwargs):  # DLA-169
     model_kwargs = dict(
         levels=[1, 1, 2, 3, 5, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, residual_root=True, **kwargs)
+        block=DlaBottleneck, shortcut_root=True, **kwargs)
     return _create_dla('dla169', pretrained, **model_kwargs)
diff --git a/timm/models/efficientnet_blocks.py b/timm/models/efficientnet_blocks.py
index 114533cf..040785f6 100644
--- a/timm/models/efficientnet_blocks.py
+++ b/timm/models/efficientnet_blocks.py
@@ -184,7 +184,7 @@ class DepthwiseSeparableConv(nn.Module):
         return info

     def forward(self, x):
-        residual = x
+        shortcut = x

         x = self.conv_dw(x)
         x = self.bn1(x)
@@ -200,7 +200,7 @@ class DepthwiseSeparableConv(nn.Module):
         if self.has_residual:
             if self.drop_path_rate > 0.:
                 x = drop_path(x, self.drop_path_rate, self.training)
-            x += residual
+            x += shortcut
         return x


@@ -258,7 +258,7 @@ class InvertedResidual(nn.Module):
         return info

     def forward(self, x):
-        residual = x
+        shortcut = x

         # Point-wise expansion
         x = self.conv_pw(x)
@@ -281,7 +281,7 @@ class InvertedResidual(nn.Module):
         if self.has_residual:
             if self.drop_path_rate > 0.:
                 x = drop_path(x, self.drop_path_rate, self.training)
-            x += residual
+            x += shortcut

         return x

@@ -308,7 +308,7 @@ class CondConvResidual(InvertedResidual):
         self.routing_fn = nn.Linear(in_chs, self.num_experts)

     def forward(self, x):
-        residual = x
+        shortcut = x

         # CondConv routing
         pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1)
@@ -335,7 +335,7 @@ class CondConvResidual(InvertedResidual):
         if self.has_residual:
             if self.drop_path_rate > 0.:
                 x = drop_path(x, self.drop_path_rate, self.training)
-            x += residual
+            x += shortcut
         return x


@@ -390,7 +390,7 @@ class EdgeResidual(nn.Module):
         return info

     def forward(self, x):
-        residual = x
+        shortcut = x

         # Expansion convolution
         x = self.conv_exp(x)
@@ -408,6 +408,6 @@ class EdgeResidual(nn.Module):
         if self.has_residual:
             if self.drop_path_rate > 0.:
                 x = drop_path(x, self.drop_path_rate, self.training)
-            x += residual
+            x += shortcut

         return x
diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py
index 76761d1c..358fb4c7 100644
--- a/timm/models/ghostnet.py
+++ b/timm/models/ghostnet.py
@@ -112,7 +112,7 @@ class GhostBottleneck(nn.Module):

     def forward(self, x):
-        residual = x
+        shortcut = x

         # 1st ghost bottleneck
         x = self.ghost1(x)
@@ -129,7 +129,7 @@ class GhostBottleneck(nn.Module):

         # 2nd ghost bottleneck
         x = self.ghost2(x)
-        x += self.shortcut(residual)
+        x += self.shortcut(shortcut)
         return x
diff --git a/timm/models/res2net.py b/timm/models/res2net.py
index 977d872f..282baba3 100644
--- a/timm/models/res2net.py
+++ b/timm/models/res2net.py
@@ -91,7 +91,7 @@ class Bottle2neck(nn.Module):
             nn.init.zeros_(self.bn3.weight)

     def forward(self, x):
-        residual = x
+        shortcut = x

         out = self.conv1(x)
         out = self.bn1(out)
@@ -124,9 +124,9 @@ class Bottle2neck(nn.Module):
             out = self.se(out)

         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)

-        out += residual
+        out += shortcut
         out = self.relu(out)

         return out
diff --git a/timm/models/resnest.py b/timm/models/resnest.py
index 154e250c..ac3b2559 100644
--- a/timm/models/resnest.py
+++ b/timm/models/resnest.py
@@ -105,7 +105,7 @@ class ResNestBottleneck(nn.Module):
             nn.init.zeros_(self.bn3.weight)

     def forward(self, x):
-        residual = x
+        shortcut = x

         out = self.conv1(x)
         out = self.bn1(out)
@@ -132,9 +132,9 @@ class ResNestBottleneck(nn.Module):
             out = self.drop_block(out)

         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)

-        out += residual
+        out += shortcut
         out = self.act3(out)
         return out
diff --git a/timm/models/resnet.py b/timm/models/resnet.py
index 7fd47057..491d9acb 100644
--- a/timm/models/resnet.py
+++ b/timm/models/resnet.py
@@ -315,7 +315,7 @@ class BasicBlock(nn.Module):
             nn.init.zeros_(self.bn2.weight)

     def forward(self, x):
-        residual = x
+        shortcut = x

         x = self.conv1(x)
         x = self.bn1(x)
@@ -337,8 +337,8 @@ class BasicBlock(nn.Module):
             x = self.drop_path(x)

         if self.downsample is not None:
-            residual = self.downsample(residual)
-        x += residual
+            shortcut = self.downsample(shortcut)
+        x += shortcut
         x = self.act2(x)

         return x
@@ -385,7 +385,7 @@ class Bottleneck(nn.Module):
             nn.init.zeros_(self.bn3.weight)

     def forward(self, x):
-        residual = x
+        shortcut = x

         x = self.conv1(x)
         x = self.bn1(x)
@@ -413,8 +413,8 @@ class Bottleneck(nn.Module):
             x = self.drop_path(x)

         if self.downsample is not None:
-            residual = self.downsample(residual)
-        x += residual
+            shortcut = self.downsample(shortcut)
+        x += shortcut
         x = self.act3(x)

         return x
diff --git a/timm/models/senet.py b/timm/models/senet.py
index 8227a453..3d0ba7b3 100644
--- a/timm/models/senet.py
+++ b/timm/models/senet.py
@@ -92,7 +92,7 @@ class Bottleneck(nn.Module):
     """

     def forward(self, x):
-        residual = x
+        shortcut = x

         out = self.conv1(x)
         out = self.bn1(out)
@@ -106,9 +106,9 @@ class Bottleneck(nn.Module):
         out = self.bn3(out)

         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)

-        out = self.se_module(out) + residual
+        out = self.se_module(out) + shortcut
         out = self.relu(out)

         return out
@@ -204,7 +204,7 @@ class SEResNetBlock(nn.Module):
         self.stride = stride

     def forward(self, x):
-        residual = x
+        shortcut = x

         out = self.conv1(x)
         out = self.bn1(out)
@@ -215,9 +215,9 @@ class SEResNetBlock(nn.Module):
         out = self.relu(out)

         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)

-        out = self.se_module(out) + residual
+        out = self.se_module(out) + shortcut
         out = self.relu(out)

         return out
diff --git a/timm/models/sknet.py b/timm/models/sknet.py
index bd9dd393..eb7ad8c3 100644
--- a/timm/models/sknet.py
+++ b/timm/models/sknet.py
@@ -76,7 +76,7 @@ class SelectiveKernelBasic(nn.Module):
             nn.init.zeros_(self.conv2.bn.weight)

     def forward(self, x):
-        residual = x
+        shortcut = x
         x = self.conv1(x)
         x = self.conv2(x)
         if self.se is not None:
@@ -84,8 +84,8 @@ class SelectiveKernelBasic(nn.Module):
         if self.drop_path is not None:
             x = self.drop_path(x)
         if self.downsample is not None:
-            residual = self.downsample(residual)
-        x += residual
+            shortcut = self.downsample(shortcut)
+        x += shortcut
         x = self.act(x)
         return x

@@ -124,7 +124,7 @@ class SelectiveKernelBottleneck(nn.Module):
             nn.init.zeros_(self.conv3.bn.weight)

     def forward(self, x):
-        residual = x
+        shortcut = x
         x = self.conv1(x)
         x = self.conv2(x)
         x = self.conv3(x)
@@ -133,8 +133,8 @@ class SelectiveKernelBottleneck(nn.Module):
         if self.drop_path is not None:
             x = self.drop_path(x)
         if self.downsample is not None:
-            residual = self.downsample(residual)
-        x += residual
+            shortcut = self.downsample(shortcut)
+        x += shortcut
         x = self.act(x)
         return x
diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py
index cec51cf4..9fb34c20 100644
--- a/timm/models/tresnet.py
+++ b/timm/models/tresnet.py
@@ -89,9 +89,9 @@ class BasicBlock(nn.Module):

     def forward(self, x):
         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)
         else:
-            residual = x
+            shortcut = x

         out = self.conv1(x)
         out = self.conv2(out)
@@ -99,7 +99,7 @@ class BasicBlock(nn.Module):
         if self.se is not None:
             out = self.se(out)

-        out += residual
+        out += shortcut
         out = self.relu(out)

         return out
@@ -136,9 +136,9 @@ class Bottleneck(nn.Module):

     def forward(self, x):
         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)
         else:
-            residual = x
+            shortcut = x

         out = self.conv1(x)
         out = self.conv2(out)
@@ -146,7 +146,7 @@ class Bottleneck(nn.Module):
             out = self.se(out)

         out = self.conv3(out)
-        out = out + residual  # no inplace
+        out = out + shortcut  # no inplace
         out = self.relu(out)

         return out
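For readers unfamiliar with the naming issue being fixed: in a residual block computing out = f(x) + x, the learned term f(x) is the residual, while the saved input x (possibly downsampled) is the shortcut or identity path. The old code bound the saved input to a variable named `residual`, which is exactly backwards. A toy illustration of the corrected naming (hypothetical block, not from timm):

    import torch.nn as nn

    class ToyBlock(nn.Module):
        def __init__(self, dim=8):
            super().__init__()
            self.f = nn.Linear(dim, dim)  # stand-in for the conv/bn stack

        def forward(self, x):
            shortcut = x            # identity path kept for the skip connection
            residual = self.f(x)    # what the block actually learns
            return residual + shortcut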
From 072155951104230c2b5f3bbfb31acc694ee2fa0a Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 4 May 2021 21:40:39 -0700
Subject: [PATCH 6/8] Improved (hopefully) init for SA/SA-like layers used in
 ByoaNets

---
 timm/models/byoanet.py                | 2 ++
 timm/models/layers/bottleneck_attn.py | 6 ++++++
 timm/models/layers/halo_attn.py       | 9 +++++++++
 timm/models/layers/lambda_layer.py    | 6 ++++++
 timm/models/layers/swin_attn.py       | 6 +++++-
 5 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py
index df88535d..da9e513b 100644
--- a/timm/models/byoanet.py
+++ b/timm/models/byoanet.py
@@ -294,6 +294,8 @@ class SelfAttnBlock(nn.Module):
     def init_weights(self, zero_init_last_bn=False):
         if zero_init_last_bn:
             nn.init.zeros_(self.conv3_1x1.bn.weight)
+        if hasattr(self.self_attn, 'reset_parameters'):
+            self.self_attn.reset_parameters()

     def forward(self, x):
         shortcut = self.shortcut(x)
diff --git a/timm/models/layers/bottleneck_attn.py b/timm/models/layers/bottleneck_attn.py
index 0bb0e27b..9604e8a6 100644
--- a/timm/models/layers/bottleneck_attn.py
+++ b/timm/models/layers/bottleneck_attn.py
@@ -21,6 +21,7 @@ import torch.nn as nn
 import torch.nn.functional as F

 from .helpers import to_2tuple
+from .weight_init import trunc_normal_


 def rel_logits_1d(q, rel_k, permute_mask: List[int]):
@@ -101,6 +102,11 @@ class BottleneckAttn(nn.Module):

         self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()

+    def reset_parameters(self):
+        trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)
+        trunc_normal_(self.pos_embed.height_rel, std=self.scale)
+        trunc_normal_(self.pos_embed.width_rel, std=self.scale)
+
     def forward(self, x):
         B, C, H, W = x.shape
         assert H == self.pos_embed.height and W == self.pos_embed.width
diff --git a/timm/models/layers/halo_attn.py b/timm/models/layers/halo_attn.py
index 8452aa94..87cae895 100644
--- a/timm/models/layers/halo_attn.py
+++ b/timm/models/layers/halo_attn.py
@@ -25,6 +25,8 @@ import torch
 from torch import nn
 import torch.nn.functional as F

+from .weight_init import trunc_normal_
+

 def rel_logits_1d(q, rel_k, permute_mask: List[int]):
     """ Compute relative logits along one dimension
@@ -124,6 +126,13 @@ class HaloAttn(nn.Module):
         self.pos_embed = PosEmbedRel(
             block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head, scale=self.scale)

+    def reset_parameters(self):
+        std = self.q.weight.shape[1] ** -0.5  # fan-in
+        trunc_normal_(self.q.weight, std=std)
+        trunc_normal_(self.kv.weight, std=std)
+        trunc_normal_(self.pos_embed.height_rel, std=self.scale)
+        trunc_normal_(self.pos_embed.width_rel, std=self.scale)
+
     def forward(self, x):
         B, C, H, W = x.shape
         assert H % self.block_size == 0 and W % self.block_size == 0
diff --git a/timm/models/layers/lambda_layer.py b/timm/models/layers/lambda_layer.py
index c89982af..2d1027a1 100644
--- a/timm/models/layers/lambda_layer.py
+++ b/timm/models/layers/lambda_layer.py
@@ -24,6 +24,7 @@ import torch
 from torch import nn
 import torch.nn.functional as F

+from .weight_init import trunc_normal_


 class LambdaLayer(nn.Module):
@@ -36,6 +37,7 @@ class LambdaLayer(nn.Module):
             self, dim, dim_out=None, stride=1, num_heads=4, dim_head=16, r=7, qkv_bias=False):
         super().__init__()
+        self.dim = dim
         self.dim_out = dim_out or dim
         self.dim_k = dim_head  # query depth 'k'
         self.num_heads = num_heads
@@ -55,6 +57,10 @@ class LambdaLayer(nn.Module):

         self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()

+    def reset_parameters(self):
+        trunc_normal_(self.qkv.weight, std=self.dim ** -0.5)
+        trunc_normal_(self.conv_lambda.weight, std=self.dim_k ** -0.5)
+
     def forward(self, x):
         B, C, H, W = x.shape
         M = H * W
diff --git a/timm/models/layers/swin_attn.py b/timm/models/layers/swin_attn.py
index 46dacb62..02131bbc 100644
--- a/timm/models/layers/swin_attn.py
+++ b/timm/models/layers/swin_attn.py
@@ -107,6 +107,7 @@ class WindowAttention(nn.Module):
         self.relative_position_bias_table = nn.Parameter(
             # 2 * Wh - 1 * 2 * Ww - 1, nH
             torch.zeros((2 * self.win_size - 1) * (2 * self.win_size - 1), num_heads))
+        trunc_normal_(self.relative_position_bias_table, std=.02)

         # get pair-wise relative position index for each token inside the window
         coords_h = torch.arange(self.win_size)
@@ -120,13 +121,16 @@ class WindowAttention(nn.Module):
         relative_coords[:, :, 0] *= 2 * self.win_size - 1
         relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
         self.register_buffer("relative_position_index", relative_position_index)
-        trunc_normal_(self.relative_position_bias_table, std=.02)

         self.qkv = nn.Linear(dim, self.dim_out * 3, bias=qkv_bias)
         self.attn_drop = nn.Dropout(attn_drop)
         self.softmax = nn.Softmax(dim=-1)
         self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()

+    def reset_parameters(self):
+        trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)
+        trunc_normal_(self.relative_position_bias_table, std=.02)
+
     def forward(self, x):
         B, C, H, W = x.shape
         x = x.permute(0, 2, 3, 1)
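The pattern shared by the new reset_parameters() methods is a truncated-normal init with std scaled by the inverse square root of fan-in (weight.shape[1] for these 1x1 projections), with the relative-position parameters scaled by the attention scale instead. A condensed sketch of the same recipe applied to a standalone projection (assumes trunc_normal_ as exported from timm.models.layers):

    import torch.nn as nn
    from timm.models.layers import trunc_normal_

    qkv = nn.Conv2d(256, 3 * 256, kernel_size=1, bias=False)
    std = qkv.weight.shape[1] ** -0.5  # fan-in of the 1x1 projection
    trunc_normal_(qkv.weight, std=std)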
+""" +from functools import partial + +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .layers import DropPath, to_2tuple, trunc_normal_, lecun_normal_ +from .registry import register_model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.patch_grid = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.patch_grid[0] * self.patch_grid[1] + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) + x = self.norm(x) + return x + + +class MixerBlock(nn.Module): + + def __init__( + self, dim, seq_len, tokens_dim, channels_dim, + norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.): + super().__init__() + self.norm1 = norm_layer(dim) + self.mlp_token = Mlp(seq_len, tokens_dim, act_layer=act_layer, drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + self.mlp_channels = Mlp(dim, channels_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.mlp_token(self.norm1(x).transpose(1, 2)).transpose(1, 2)) + x = x + self.drop_path(self.mlp_channels(self.norm2(x))) + return x + + +class MlpMixer(nn.Module): + + def __init__( + self, + num_classes=1000, + img_size=224, + in_chans=3, + patch_size=16, + num_blocks=8, + hidden_dim=512, + tokens_dim=256, + channels_dim=2048, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, + drop=0., + drop_path=0., + ): + super().__init__() + self.num_classes = num_classes + + self.stem = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=hidden_dim) + # FIXME drop_path (stochastic depth scaling rule?) 
+ self.blocks = nn.Sequential(*[ + MixerBlock( + hidden_dim, self.stem.num_patches, tokens_dim, channels_dim, + norm_layer=norm_layer, act_layer=act_layer, drop=drop, drop_path=drop_path) + for _ in range(num_blocks)]) + self.norm = nn.LayerNorm(hidden_dim) + self.head = nn.Linear(hidden_dim, self.num_classes) # zero init + + def forward(self, x): + x = self.stem(x) + x = self.blocks(x) + x = self.norm(x) + x = x.mean(dim=1) + x = self.head(x) + return x + + +@register_model +def mixer_small_p16(pretrained=False, **kwargs): + model = MlpMixer() + model.default_cfg = _cfg() + return model + + +@register_model +def mixer_base_p16(pretrained=False, **kwargs): + model = MlpMixer(num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072) + model.default_cfg = _cfg() + return model \ No newline at end of file From 2d8b09fe8bd846c72ca6a9a5fc31927e90e41628 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 5 May 2021 15:59:40 -0700 Subject: [PATCH 8/8] Add official pretrained weights to MLP-Mixer, complete model cfgs. --- tests/test_models.py | 2 +- timm/models/mlp_mixer.py | 182 +++++++++++++++++++++++++++++++++++---- 2 files changed, 167 insertions(+), 17 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 0d3fde76..2b7a9143 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -15,7 +15,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'): torch._C._jit_set_profiling_mode(False) # transformer models don't support many of the spatial / feature based model functionalities -NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*'] +NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'mixer_*'] NUM_NON_STD = len(NON_STD_FILTERS) # exclude models that cause specific test failures diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py index 612345a0..e044e961 100644 --- a/timm/models/mlp_mixer.py +++ b/timm/models/mlp_mixer.py @@ -1,10 +1,23 @@ """ MLP-Mixer in PyTorch +Official JAX impl: https://github.com/google-research/vision_transformer/blob/linen/vit_jax/models_mixer.py + Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 -NOTE this is a very early stage first run through, the param counts aren't matching paper so -something is up... +@article{tolstikhin2021, + title={MLP-Mixer: An all-MLP Architecture for Vision}, + author={Tolstikhin, Ilya and Houlsby, Neil and Kolesnikov, Alexander and Beyer, Lucas and Zhai, Xiaohua and Unterthiner, + Thomas and Yung, Jessica and Keysers, Daniel and Uszkoreit, Jakob and Lucic, Mario and Dosovitskiy, Alexey}, + journal={arXiv preprint arXiv:2105.01601}, + year={2021} +} + +A thank you to paper authors for releasing code and weights. 
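The NOTE in the new module flags that parameter counts don't yet match the paper; one way to verify is to count them directly against the paper's table (from memory Mixer-B/16 should land around 59M, so treat that figure as an assumption to check):

    import torch
    from timm.models.mlp_mixer import MlpMixer

    # Mixer-B/16 dims per the config used above for mixer_base_p16.
    model = MlpMixer(num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072)
    n_params = sum(p.numel() for p in model.parameters())
    print(f'{n_params / 1e6:.1f}M parameters')  # compare against the paper's table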
From 2d8b09fe8bd846c72ca6a9a5fc31927e90e41628 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 5 May 2021 15:59:40 -0700
Subject: [PATCH 8/8] Add official pretrained weights to MLP-Mixer, complete
 model cfgs.

---
 tests/test_models.py     |   2 +-
 timm/models/mlp_mixer.py | 182 +++++++++++++++++++++++++++++++++++----
 2 files changed, 167 insertions(+), 17 deletions(-)

diff --git a/tests/test_models.py b/tests/test_models.py
index 0d3fde76..2b7a9143 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -15,7 +15,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
     torch._C._jit_set_profiling_mode(False)

 # transformer models don't support many of the spatial / feature based model functionalities
-NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*']
+NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'mixer_*']
 NUM_NON_STD = len(NON_STD_FILTERS)

 # exclude models that cause specific test failures
diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py
index 612345a0..e044e961 100644
--- a/timm/models/mlp_mixer.py
+++ b/timm/models/mlp_mixer.py
@@ -1,10 +1,23 @@
 """ MLP-Mixer in PyTorch

+Official JAX impl: https://github.com/google-research/vision_transformer/blob/linen/vit_jax/models_mixer.py
+
 Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601

-NOTE this is a very early stage first run through, the param counts aren't matching paper so
-something is up...
+@article{tolstikhin2021,
+  title={MLP-Mixer: An all-MLP Architecture for Vision},
+  author={Tolstikhin, Ilya and Houlsby, Neil and Kolesnikov, Alexander and Beyer, Lucas and Zhai, Xiaohua and Unterthiner,
+      Thomas and Yung, Jessica and Keysers, Daniel and Uszkoreit, Jakob and Lucic, Mario and Dosovitskiy, Alexey},
+  journal={arXiv preprint arXiv:2105.01601},
+  year={2021}
+}
+
+A thank you to paper authors for releasing code and weights.
+
+Hacked together by / Copyright 2021 Ross Wightman
 """
+import math
+from copy import deepcopy
 from functools import partial

 import torch
@@ -12,7 +25,7 @@ import torch.nn as nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg, overlay_external_default_cfg
-from .layers import DropPath, to_2tuple, trunc_normal_, lecun_normal_
+from .layers import DropPath, to_2tuple, lecun_normal_
 from .registry import register_model


@@ -20,14 +33,39 @@ def _cfg(url='', **kwargs):
     return {
         'url': url,
         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
-        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'patch_embed.proj', 'classifier': 'head',
+        'crop_pct': 0.875, 'interpolation': 'bicubic', 'fixed_input_size': True,
+        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
+        'first_conv': 'stem.proj', 'classifier': 'head',
         **kwargs
     }


+default_cfgs = dict(
+    mixer_s32_224=_cfg(),
+    mixer_s16_224=_cfg(),
+    mixer_b32_224=_cfg(),
+    mixer_b16_224=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224-76587d61.pth',
+    ),
+    mixer_b16_224_in21k=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224_in21k-617b3de2.pth',
+        num_classes=21843
+    ),
+    mixer_l32_224=_cfg(),
+    mixer_l16_224=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224-92f9adc4.pth',
+    ),
+    mixer_l16_224_in21k=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224_in21k-846aa33c.pth',
+        num_classes=21843
+    ),
+)
+
+
 class Mlp(nn.Module):
+    """ MLP Block
+    NOTE: same impl as ViT, move to common location
+    """
     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
         super().__init__()
         out_features = out_features or in_features
@@ -48,6 +86,7 @@ class Mlp(nn.Module):

 class PatchEmbed(nn.Module):
     """ Image to Patch Embedding
+    NOTE: same impl as ViT, move to common location
     """
     def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None):
         super().__init__()
@@ -78,13 +117,13 @@ class MixerBlock(nn.Module):
             norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.):
         super().__init__()
         self.norm1 = norm_layer(dim)
-        self.mlp_token = Mlp(seq_len, tokens_dim, act_layer=act_layer, drop=drop)
+        self.mlp_tokens = Mlp(seq_len, tokens_dim, act_layer=act_layer, drop=drop)
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.norm2 = norm_layer(dim)
         self.mlp_channels = Mlp(dim, channels_dim, act_layer=act_layer, drop=drop)

     def forward(self, x):
-        x = x + self.drop_path(self.mlp_token(self.norm1(x).transpose(1, 2)).transpose(1, 2))
+        x = x + self.drop_path(self.mlp_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2))
         x = x + self.drop_path(self.mlp_channels(self.norm2(x)))
         return x

@@ -105,6 +144,7 @@ class MlpMixer(nn.Module):
             act_layer=nn.GELU,
             drop=0.,
             drop_path=0.,
+            nlhb=False,
     ):
         super().__init__()
         self.num_classes = num_classes
@@ -116,9 +156,16 @@ class MlpMixer(nn.Module):
                 hidden_dim, self.stem.num_patches, tokens_dim, channels_dim,
                 norm_layer=norm_layer, act_layer=act_layer, drop=drop, drop_path=drop_path)
             for _ in range(num_blocks)])
-        self.norm = nn.LayerNorm(hidden_dim)
+        self.norm = norm_layer(hidden_dim)
         self.head = nn.Linear(hidden_dim, self.num_classes)  # zero init

+        self.init_weights(nlhb=nlhb)
+
+    def init_weights(self, nlhb=False):
+        head_bias = -math.log(self.num_classes) if nlhb else 0.
+        for n, m in self.named_modules():
+            _init_weights(m, n, head_bias=head_bias)
+
     def forward(self, x):
         x = self.stem(x)
         x = self.blocks(x)
@@ -128,15 +175,118 @@ class MlpMixer(nn.Module):
         return x


+def _init_weights(m, n: str, head_bias: float = 0.):
+    """ Mixer weight initialization (trying to match Flax defaults)
+    """
+    if isinstance(m, nn.Linear):
+        if n.startswith('head'):
+            nn.init.zeros_(m.weight)
+            nn.init.constant_(m.bias, head_bias)
+        else:
+            nn.init.xavier_uniform_(m.weight)
+            if m.bias is not None:
+                if 'mlp' in n:
+                    nn.init.normal_(m.bias, std=1e-6)
+                else:
+                    nn.init.zeros_(m.bias)
+    elif isinstance(m, nn.Conv2d):
+        lecun_normal_(m.weight)
+        if m.bias is not None:
+            nn.init.zeros_(m.bias)
+    elif isinstance(m, nn.LayerNorm):
+        nn.init.zeros_(m.bias)
+        nn.init.ones_(m.weight)
+
+
+def _create_mixer(variant, pretrained=False, default_cfg=None, **kwargs):
+    if default_cfg is None:
+        default_cfg = deepcopy(default_cfgs[variant])
+    overlay_external_default_cfg(default_cfg, kwargs)
+    default_num_classes = default_cfg['num_classes']
+    default_img_size = default_cfg['input_size'][-2:]
+    num_classes = kwargs.pop('num_classes', default_num_classes)
+    img_size = kwargs.pop('img_size', default_img_size)
+
+    if kwargs.get('features_only', None):
+        raise RuntimeError('features_only not implemented for MLP-Mixer models.')
+
+    model = build_model_with_cfg(
+        MlpMixer, variant, pretrained,
+        default_cfg=default_cfg,
+        img_size=img_size,
+        num_classes=num_classes,
+        **kwargs)
+
+    return model
+
+
+@register_model
+def mixer_s32_224(pretrained=False, **kwargs):
+    """ Mixer-S/32 224x224
+    """
+    model_args = dict(patch_size=32, num_blocks=8, hidden_dim=512, tokens_dim=256, channels_dim=2048, **kwargs)
+    model = _create_mixer('mixer_s32_224', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def mixer_s16_224(pretrained=False, **kwargs):
+    """ Mixer-S/16 224x224
+    """
+    model_args = dict(patch_size=16, num_blocks=8, hidden_dim=512, tokens_dim=256, channels_dim=2048, **kwargs)
+    model = _create_mixer('mixer_s16_224', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def mixer_b32_224(pretrained=False, **kwargs):
+    """ Mixer-B/32 224x224
+    """
+    model_args = dict(patch_size=32, num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072, **kwargs)
+    model = _create_mixer('mixer_b32_224', pretrained=pretrained, **model_args)
+    return model
+
+
 @register_model
-def mixer_small_p16(pretrained=False, **kwargs):
-    model = MlpMixer()
-    model.default_cfg = _cfg()
+def mixer_b16_224(pretrained=False, **kwargs):
+    """ Mixer-B/16 224x224. ImageNet-1k pretrained weights.
+    """
+    model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072, **kwargs)
+    model = _create_mixer('mixer_b16_224', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def mixer_b16_224_in21k(pretrained=False, **kwargs):
+    """ Mixer-B/16 224x224. ImageNet-21k pretrained weights.
+    """
+    model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072, **kwargs)
+    model = _create_mixer('mixer_b16_224_in21k', pretrained=pretrained, **model_args)
     return model


 @register_model
-def mixer_base_p16(pretrained=False, **kwargs):
-    model = MlpMixer(num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072)
-    model.default_cfg = _cfg()
-    return model
\ No newline at end of file
+def mixer_l32_224(pretrained=False, **kwargs):
+    """ Mixer-L/32 224x224.
+    """
+    model_args = dict(patch_size=32, num_blocks=24, hidden_dim=1024, tokens_dim=512, channels_dim=4096, **kwargs)
+    model = _create_mixer('mixer_l32_224', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def mixer_l16_224(pretrained=False, **kwargs):
+    """ Mixer-L/16 224x224. ImageNet-1k pretrained weights.
+    """
+    model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, tokens_dim=512, channels_dim=4096, **kwargs)
+    model = _create_mixer('mixer_l16_224', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def mixer_l16_224_in21k(pretrained=False, **kwargs):
+    """ Mixer-L/16 224x224. ImageNet-21k pretrained weights.
+    """
+    model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, tokens_dim=512, channels_dim=4096, **kwargs)
+    model = _create_mixer('mixer_l16_224_in21k', pretrained=pretrained, **model_args)
+    return model
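With the final patch applied, the new variants resolve through the standard timm factory. A quick smoke test (set pretrained=True to pull the weights from the default_cfgs URLs above):

    import torch
    import timm

    model = timm.create_model('mixer_b16_224', pretrained=False)
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # torch.Size([1, 1000])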