Tweak some comments, add SKNet models with weights to sotabench, remove an unused branch

pull/88/head
Ross Wightman 5 years ago
parent 91e2b33d72
commit 569419b38d

@ -56,8 +56,7 @@ model_list = [
model_desc='Trained from scratch in PyTorch w/ RandAugment'), model_desc='Trained from scratch in PyTorch w/ RandAugment'),
_entry('efficientnet_es', 'EfficientNet-EdgeTPU-S', '1905.11946', _entry('efficientnet_es', 'EfficientNet-EdgeTPU-S', '1905.11946',
model_desc='Trained from scratch in PyTorch w/ RandAugment'), model_desc='Trained from scratch in PyTorch w/ RandAugment'),
_entry('fbnetc_100', 'FBNet-C', '1812.03443',
model_desc='Trained in PyTorch with RMSProp, exponential LR decay'),
_entry('gluon_inception_v3', 'Inception V3', '1512.00567', model_desc='Ported from GluonCV Model Zoo'), _entry('gluon_inception_v3', 'Inception V3', '1512.00567', model_desc='Ported from GluonCV Model Zoo'),
_entry('gluon_resnet18_v1b', 'ResNet-18', '1812.01187', model_desc='Ported from GluonCV Model Zoo'), _entry('gluon_resnet18_v1b', 'ResNet-18', '1812.01187', model_desc='Ported from GluonCV Model Zoo'),
_entry('gluon_resnet34_v1b', 'ResNet-34', '1812.01187', model_desc='Ported from GluonCV Model Zoo'), _entry('gluon_resnet34_v1b', 'ResNet-34', '1812.01187', model_desc='Ported from GluonCV Model Zoo'),
@ -82,14 +81,22 @@ model_list = [
_entry('gluon_seresnext101_64x4d', 'SE-ResNeXt-101 64x4d', '1812.01187', model_desc='Ported from GluonCV Model Zoo'), _entry('gluon_seresnext101_64x4d', 'SE-ResNeXt-101 64x4d', '1812.01187', model_desc='Ported from GluonCV Model Zoo'),
_entry('gluon_xception65', 'Modified Aligned Xception', '1802.02611', batch_size=BATCH_SIZE//2, _entry('gluon_xception65', 'Modified Aligned Xception', '1802.02611', batch_size=BATCH_SIZE//2,
model_desc='Ported from GluonCV Model Zoo'), model_desc='Ported from GluonCV Model Zoo'),
_entry('mixnet_xl', 'MixNet-XL', '1907.09595', model_desc="My own scaling beyond paper's MixNet Large"), _entry('mixnet_xl', 'MixNet-XL', '1907.09595', model_desc="My own scaling beyond paper's MixNet Large"),
_entry('mixnet_l', 'MixNet-L', '1907.09595'), _entry('mixnet_l', 'MixNet-L', '1907.09595'),
_entry('mixnet_m', 'MixNet-M', '1907.09595'), _entry('mixnet_m', 'MixNet-M', '1907.09595'),
_entry('mixnet_s', 'MixNet-S', '1907.09595'), _entry('mixnet_s', 'MixNet-S', '1907.09595'),
_entry('fbnetc_100', 'FBNet-C', '1812.03443',
model_desc='Trained in PyTorch with RMSProp, exponential LR decay'),
_entry('mnasnet_100', 'MnasNet-B1', '1807.11626'), _entry('mnasnet_100', 'MnasNet-B1', '1807.11626'),
_entry('semnasnet_100', 'MnasNet-A1', '1807.11626'),
_entry('spnasnet_100', 'Single-Path NAS', '1904.02877',
model_desc='Trained in PyTorch with SGD, cosine LR decay'),
_entry('mobilenetv3_rw', 'MobileNet V3-Large 1.0', '1905.02244', _entry('mobilenetv3_rw', 'MobileNet V3-Large 1.0', '1905.02244',
model_desc='Trained in PyTorch with RMSProp, exponential LR decay, and hyper-params matching ' model_desc='Trained in PyTorch with RMSProp, exponential LR decay, and hyper-params matching '
'paper as closely as possible.'), 'paper as closely as possible.'),
_entry('resnet18', 'ResNet-18', '1812.01187'), _entry('resnet18', 'ResNet-18', '1812.01187'),
_entry('resnet26', 'ResNet-26', '1812.01187', model_desc='Block cfg of ResNet-34 w/ Bottleneck'), _entry('resnet26', 'ResNet-26', '1812.01187', model_desc='Block cfg of ResNet-34 w/ Bottleneck'),
_entry('resnet26d', 'ResNet-26-D', '1812.01187', _entry('resnet26d', 'ResNet-26-D', '1812.01187',
@ -103,7 +110,7 @@ model_list = [
_entry('resnext50d_32x4d', 'ResNeXt-50-D 32x4d', '1812.01187', _entry('resnext50d_32x4d', 'ResNeXt-50-D 32x4d', '1812.01187',
model_desc="'D' variant (3x3 deep stem w/ avg-pool downscale). Trained with " model_desc="'D' variant (3x3 deep stem w/ avg-pool downscale). Trained with "
"SGD w/ cosine LR decay, random-erasing (gaussian per-pixel noise) and label-smoothing"), "SGD w/ cosine LR decay, random-erasing (gaussian per-pixel noise) and label-smoothing"),
_entry('semnasnet_100', 'MnasNet-A1', '1807.11626'),
_entry('seresnet18', 'SE-ResNet-18', '1709.01507'), _entry('seresnet18', 'SE-ResNet-18', '1709.01507'),
_entry('seresnet34', 'SE-ResNet-34', '1709.01507'), _entry('seresnet34', 'SE-ResNet-34', '1709.01507'),
_entry('seresnext26_32x4d', 'SE-ResNeXt-26 32x4d', '1709.01507', _entry('seresnext26_32x4d', 'SE-ResNeXt-26 32x4d', '1709.01507',
@ -114,8 +121,9 @@ model_list = [
model_desc='Block cfg of SE-ResNeXt-34 w/ Bottleneck, deep tiered stem, and avg-pool in downsample layers.'), model_desc='Block cfg of SE-ResNeXt-34 w/ Bottleneck, deep tiered stem, and avg-pool in downsample layers.'),
_entry('seresnext26tn_32x4d', 'SE-ResNeXt-26-TN 32x4d', '1812.01187', _entry('seresnext26tn_32x4d', 'SE-ResNeXt-26-TN 32x4d', '1812.01187',
model_desc='Block cfg of SE-ResNeXt-34 w/ Bottleneck, deep tiered narrow stem, and avg-pool in downsample layers.'), model_desc='Block cfg of SE-ResNeXt-34 w/ Bottleneck, deep tiered narrow stem, and avg-pool in downsample layers.'),
_entry('spnasnet_100', 'Single-Path NAS', '1904.02877',
model_desc='Trained in PyTorch with SGD, cosine LR decay'), _entry('skresnet18', 'SK-ResNet-18', '1903.06586'),
_entry('skresnext50_32x4d', 'SKNet-50', '1903.06586'),
_entry('tf_efficientnet_b0', 'EfficientNet-B0 (AutoAugment)', '1905.11946', _entry('tf_efficientnet_b0', 'EfficientNet-B0 (AutoAugment)', '1905.11946',
model_desc='Ported from official Google AI Tensorflow weights'), model_desc='Ported from official Google AI Tensorflow weights'),

@ -2,6 +2,9 @@
Experimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/1807.06521 Experimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/1807.06521
WARNING: Results with these attention layers have been mixed. They can significantly reduce performance on
some tasks, especially fine-grained it seems. I may end up removing this impl.
Hacked together by Ross Wightman Hacked together by Ross Wightman
""" """

@ -1,4 +1,7 @@
""" Conditional Convolution """ PyTorch Conditionally Parameterized Convolution (CondConv)
Paper: CondConv: Conditionally Parameterized Convolutions for Efficient Inference
(https://arxiv.org/abs/1904.04971)
Hacked together by Ross Wightman Hacked together by Ross Wightman
""" """
@ -28,7 +31,7 @@ def get_condconv_initializer(initializer, num_experts, expert_shape):
class CondConv2d(nn.Module): class CondConv2d(nn.Module):
""" Conditional Convolution """ Conditionally Parameterized Convolution
Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py
Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion: Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:

@ -42,7 +42,7 @@ class EcaModule(nn.Module):
"""Constructs an ECA module. """Constructs an ECA module.
Args: Args:
channel: Number of channels of the input feature map for use in adaptive kernel sizes channels: Number of channels of the input feature map for use in adaptive kernel sizes
for actual calculations according to channel. for actual calculations according to channel.
gamma, beta: when channel is given parameters of mapping function gamma, beta: when channel is given parameters of mapping function
refer to original paper https://arxiv.org/pdf/1910.03151.pdf refer to original paper https://arxiv.org/pdf/1910.03151.pdf

@ -1,4 +1,6 @@
""" Conditional Convolution """ PyTorch Mixed Convolution
Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595)
Hacked together by Ross Wightman Hacked together by Ross Wightman
""" """

@ -1,4 +1,6 @@
""" Selective Kernel Convolution Attention """ Selective Kernel Convolution/Attention
Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586)
Hacked together by Ross Wightman Hacked together by Ross Wightman
""" """

@ -1,3 +1,9 @@
""" Selective Kernel Networks (ResNet base)
Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586)
Hacked together by Ross Wightman
"""
import math import math
from torch import nn as nn from torch import nn as nn
@ -47,19 +53,11 @@ class SelectiveKernelBasic(nn.Module):
outplanes = planes * self.expansion outplanes = planes * self.expansion
first_dilation = first_dilation or dilation first_dilation = first_dilation or dilation
_selective_first = True # FIXME temporary, for experiments self.conv1 = SelectiveKernelConv(
if _selective_first: inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs)
self.conv1 = SelectiveKernelConv( conv_kwargs['act_layer'] = None
inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs) self.conv2 = ConvBnAct(
conv_kwargs['act_layer'] = None first_planes, outplanes, kernel_size=3, dilation=dilation, **conv_kwargs)
self.conv2 = ConvBnAct(
first_planes, outplanes, kernel_size=3, dilation=dilation, **conv_kwargs)
else:
self.conv1 = ConvBnAct(
inplanes, first_planes, kernel_size=3, stride=stride, dilation=first_dilation, **conv_kwargs)
conv_kwargs['act_layer'] = None
self.conv2 = SelectiveKernelConv(
first_planes, outplanes, dilation=dilation, **conv_kwargs, **sk_kwargs)
self.se = create_attn(attn_layer, outplanes) self.se = create_attn(attn_layer, outplanes)
self.act = act_layer(inplace=True) self.act = act_layer(inplace=True)
self.downsample = downsample self.downsample = downsample
@ -222,7 +220,7 @@ def skresnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
@register_model @register_model
def skresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): def skresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a Select Kernel ResNeXt50-32x4d model. This should be equivalent to """Constructs a Select Kernel ResNeXt50-32x4d model. This should be equivalent to
the SKNet50 model in the Select Kernel Paper the SKNet-50 model in the Select Kernel Paper
""" """
default_cfg = default_cfgs['skresnext50_32x4d'] default_cfg = default_cfgs['skresnext50_32x4d']
model = ResNet( model = ResNet(

Loading…
Cancel
Save