From 569419b38d566ecb492009abb05b5113b9265220 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 15 Feb 2020 21:18:25 -0800 Subject: [PATCH] Tweak some comments, add SKNet models with weights to sotabench, remove an unused branch --- sotabench.py | 18 +++++++++++++----- timm/models/layers/cbam.py | 3 +++ timm/models/layers/cond_conv2d.py | 7 +++++-- timm/models/layers/eca.py | 2 +- timm/models/layers/mixed_conv2d.py | 4 +++- timm/models/layers/selective_kernel.py | 4 +++- timm/models/sknet.py | 26 ++++++++++++-------------- 7 files changed, 40 insertions(+), 24 deletions(-) diff --git a/sotabench.py b/sotabench.py index 054d61c7..217c2a81 100644 --- a/sotabench.py +++ b/sotabench.py @@ -56,8 +56,7 @@ model_list = [ model_desc='Trained from scratch in PyTorch w/ RandAugment'), _entry('efficientnet_es', 'EfficientNet-EdgeTPU-S', '1905.11946', model_desc='Trained from scratch in PyTorch w/ RandAugment'), - _entry('fbnetc_100', 'FBNet-C', '1812.03443', - model_desc='Trained in PyTorch with RMSProp, exponential LR decay'), + _entry('gluon_inception_v3', 'Inception V3', '1512.00567', model_desc='Ported from GluonCV Model Zoo'), _entry('gluon_resnet18_v1b', 'ResNet-18', '1812.01187', model_desc='Ported from GluonCV Model Zoo'), _entry('gluon_resnet34_v1b', 'ResNet-34', '1812.01187', model_desc='Ported from GluonCV Model Zoo'), @@ -82,14 +81,22 @@ model_list = [ _entry('gluon_seresnext101_64x4d', 'SE-ResNeXt-101 64x4d', '1812.01187', model_desc='Ported from GluonCV Model Zoo'), _entry('gluon_xception65', 'Modified Aligned Xception', '1802.02611', batch_size=BATCH_SIZE//2, model_desc='Ported from GluonCV Model Zoo'), + _entry('mixnet_xl', 'MixNet-XL', '1907.09595', model_desc="My own scaling beyond paper's MixNet Large"), _entry('mixnet_l', 'MixNet-L', '1907.09595'), _entry('mixnet_m', 'MixNet-M', '1907.09595'), _entry('mixnet_s', 'MixNet-S', '1907.09595'), + + _entry('fbnetc_100', 'FBNet-C', '1812.03443', + model_desc='Trained in PyTorch with RMSProp, exponential LR 
decay'), _entry('mnasnet_100', 'MnasNet-B1', '1807.11626'), + _entry('semnasnet_100', 'MnasNet-A1', '1807.11626'), + _entry('spnasnet_100', 'Single-Path NAS', '1904.02877', + model_desc='Trained in PyTorch with SGD, cosine LR decay'), _entry('mobilenetv3_rw', 'MobileNet V3-Large 1.0', '1905.02244', model_desc='Trained in PyTorch with RMSProp, exponential LR decay, and hyper-params matching ' 'paper as closely as possible.'), + _entry('resnet18', 'ResNet-18', '1812.01187'), _entry('resnet26', 'ResNet-26', '1812.01187', model_desc='Block cfg of ResNet-34 w/ Bottleneck'), _entry('resnet26d', 'ResNet-26-D', '1812.01187', @@ -103,7 +110,7 @@ model_list = [ _entry('resnext50d_32x4d', 'ResNeXt-50-D 32x4d', '1812.01187', model_desc="'D' variant (3x3 deep stem w/ avg-pool downscale). Trained with " "SGD w/ cosine LR decay, random-erasing (gaussian per-pixel noise) and label-smoothing"), - _entry('semnasnet_100', 'MnasNet-A1', '1807.11626'), + _entry('seresnet18', 'SE-ResNet-18', '1709.01507'), _entry('seresnet34', 'SE-ResNet-34', '1709.01507'), _entry('seresnext26_32x4d', 'SE-ResNeXt-26 32x4d', '1709.01507', @@ -114,8 +121,9 @@ model_list = [ model_desc='Block cfg of SE-ResNeXt-34 w/ Bottleneck, deep tiered stem, and avg-pool in downsample layers.'), _entry('seresnext26tn_32x4d', 'SE-ResNeXt-26-TN 32x4d', '1812.01187', model_desc='Block cfg of SE-ResNeXt-34 w/ Bottleneck, deep tiered narrow stem, and avg-pool in downsample layers.'), - _entry('spnasnet_100', 'Single-Path NAS', '1904.02877', - model_desc='Trained in PyTorch with SGD, cosine LR decay'), + + _entry('skresnet18', 'SK-ResNet-18', '1903.06586'), + _entry('skresnext50_32x4d', 'SKNet-50', '1903.06586'), _entry('tf_efficientnet_b0', 'EfficientNet-B0 (AutoAugment)', '1905.11946', model_desc='Ported from official Google AI Tensorflow weights'), diff --git a/timm/models/layers/cbam.py b/timm/models/layers/cbam.py index 37ba1c35..81c0b6b3 100644 --- a/timm/models/layers/cbam.py +++ b/timm/models/layers/cbam.py @@ -2,6 
+2,9 @@ Experimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/1807.06521 +WARNING: Results with these attention layers have been mixed. They can significantly reduce performance on +some tasks, especially fine-grained it seems. I may end up removing this impl. + Hacked together by Ross Wightman """ diff --git a/timm/models/layers/cond_conv2d.py b/timm/models/layers/cond_conv2d.py index a7a424a6..7b038ee7 100644 --- a/timm/models/layers/cond_conv2d.py +++ b/timm/models/layers/cond_conv2d.py @@ -1,4 +1,7 @@ -""" Conditional Convolution +""" PyTorch Conditionally Parameterized Convolution (CondConv) + +Paper: CondConv: Conditionally Parameterized Convolutions for Efficient Inference +(https://arxiv.org/abs/1904.04971) Hacked together by Ross Wightman """ @@ -28,7 +31,7 @@ def get_condconv_initializer(initializer, num_experts, expert_shape): class CondConv2d(nn.Module): - """ Conditional Convolution + """ Conditionally Parameterized Convolution Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion: diff --git a/timm/models/layers/eca.py b/timm/models/layers/eca.py index 7ca5033d..cfc39710 100644 --- a/timm/models/layers/eca.py +++ b/timm/models/layers/eca.py @@ -42,7 +42,7 @@ class EcaModule(nn.Module): """Constructs an ECA module. Args: - channel: Number of channels of the input feature map for use in adaptive kernel sizes + channels: Number of channels of the input feature map for use in adaptive kernel sizes for actual calculations according to channel. 
gamma, beta: when channel is given parameters of mapping function refer to original paper https://arxiv.org/pdf/1910.03151.pdf diff --git a/timm/models/layers/mixed_conv2d.py b/timm/models/layers/mixed_conv2d.py index 3e280c03..1da469b3 100644 --- a/timm/models/layers/mixed_conv2d.py +++ b/timm/models/layers/mixed_conv2d.py @@ -1,4 +1,6 @@ -""" Conditional Convolution +""" PyTorch Mixed Convolution + +Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595) Hacked together by Ross Wightman """ diff --git a/timm/models/layers/selective_kernel.py b/timm/models/layers/selective_kernel.py index 4100aa02..cb7e29ad 100644 --- a/timm/models/layers/selective_kernel.py +++ b/timm/models/layers/selective_kernel.py @@ -1,4 +1,6 @@ -""" Selective Kernel Convolution Attention +""" Selective Kernel Convolution/Attention + +Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586) Hacked together by Ross Wightman """ diff --git a/timm/models/sknet.py b/timm/models/sknet.py index 97cd84dd..7737bfed 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -1,3 +1,9 @@ +""" Selective Kernel Networks (ResNet base) + +Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586) + +Hacked together by Ross Wightman +""" import math from torch import nn as nn @@ -47,19 +53,11 @@ class SelectiveKernelBasic(nn.Module): outplanes = planes * self.expansion first_dilation = first_dilation or dilation - _selective_first = True # FIXME temporary, for experiments - if _selective_first: - self.conv1 = SelectiveKernelConv( - inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs) - conv_kwargs['act_layer'] = None - self.conv2 = ConvBnAct( - first_planes, outplanes, kernel_size=3, dilation=dilation, **conv_kwargs) - else: - self.conv1 = ConvBnAct( - inplanes, first_planes, kernel_size=3, stride=stride, dilation=first_dilation, **conv_kwargs) - conv_kwargs['act_layer'] = None - self.conv2 = 
SelectiveKernelConv( - first_planes, outplanes, dilation=dilation, **conv_kwargs, **sk_kwargs) + self.conv1 = SelectiveKernelConv( + inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs) + conv_kwargs['act_layer'] = None + self.conv2 = ConvBnAct( + first_planes, outplanes, kernel_size=3, dilation=dilation, **conv_kwargs) self.se = create_attn(attn_layer, outplanes) self.act = act_layer(inplace=True) self.downsample = downsample @@ -222,7 +220,7 @@ def skresnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): @register_model def skresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a Select Kernel ResNeXt50-32x4d model. This should be equivalent to - the SKNet50 model in the Select Kernel Paper + the SKNet-50 model in the Selective Kernel Networks paper """ default_cfg = default_cfgs['skresnext50_32x4d'] model = ResNet(