Add Res2Net and DLA models w/ pretrained weights. Update sotabench.

pull/35/head
rwightman 5 years ago
parent 7fd0857bfe
commit adbf770f16

@ -19,13 +19,14 @@ def _entry(model_name, paper_model_name, paper_arxiv_id, batch_size=BATCH_SIZE,
ttp=ttp,
args=args)
# NOTE For any original PyTorch models, I'll remove from this list when you add to sotabench to
# avoid overlap and confusion. Please contact me.
model_list = [
#_entry('adv_inception_v3', 'Adversarial Inception V3', ),
#_entry('densenet121'), # same weights as torchvision
#_entry('densenet161'), # same weights as torchvision
#_entry('densenet169'), # same weights as torchvision
#_entry('densenet201'), # same weights as torchvision
## Weights ported by myself from other frameworks or trained myself in PyTorch
_entry('adv_inception_v3', 'Adversarial Inception V3', '1611.01236',
model_desc='Ported from official Tensorflow weights'),
_entry('ens_adv_inception_resnet_v2', 'Ensemble Adversarial Inception V3', '1705.07204',
model_desc='Ported from official Tensorflow weights'),
_entry('dpn68', 'DPN-68 (224x224)', '1707.01629'),
_entry('dpn68b', 'DPN-68b (224x224)', '1707.01629'),
_entry('dpn92', 'DPN-92 (224x224)', '1707.01629'),
@ -45,74 +46,57 @@ model_list = [
_entry('efficientnet_b0', 'EfficientNet-B0', '1905.11946'),
_entry('efficientnet_b1', 'EfficientNet-B1', '1905.11946'),
_entry('efficientnet_b2', 'EfficientNet-B2', '1905.11946'),
#_entry('ens_adv_inception_resnet_v2', 'Ensemble Adversarial Inception V3'),
_entry('fbnetc_100', 'FBNet-C', '1812.03443'),
_entry('gluon_inception_v3', 'Inception V3', '1512.00567'),
_entry('fbnetc_100', 'FBNet-C', '1812.03443',
model_desc='Trained in PyTorch with RMSProp, exponential LR decay'),
_entry('gluon_inception_v3', 'Inception V3', '1512.00567', model_desc='Ported from GluonCV Model Zoo'),
_entry('gluon_resnet18_v1b', 'ResNet-18', '1812.01187', model_desc='Ported from GluonCV Model Zoo'),
_entry('gluon_resnet34_v1b', 'ResNet-34', '1812.01187'),
_entry('gluon_resnet50_v1b', 'ResNet-50', '1812.01187'),
_entry('gluon_resnet50_v1c', 'ResNet-50-C', '1812.01187'),
_entry('gluon_resnet50_v1d', 'ResNet-50-D', '1812.01187'),
_entry('gluon_resnet50_v1s', 'ResNet-50-S', '1812.01187'),
_entry('gluon_resnet101_v1b', 'ResNet-101', '1812.01187'),
_entry('gluon_resnet101_v1c', 'ResNet-101-C', '1812.01187'),
_entry('gluon_resnet101_v1d', 'ResNet-101-D', '1812.01187'),
_entry('gluon_resnet101_v1s', 'ResNet-101-S', '1812.01187'),
_entry('gluon_resnet152_v1b', 'ResNet-152', '1812.01187'),
_entry('gluon_resnet152_v1c', 'ResNet-152-C', '1812.01187'),
_entry('gluon_resnet152_v1d', 'ResNet-152-D', '1812.01187'),
_entry('gluon_resnet152_v1s', 'ResNet-152-S', '1812.01187'),
_entry('gluon_resnext50_32x4d', 'ResNeXt-50 32x4d', '1812.01187'),
_entry('gluon_resnext101_32x4d', 'ResNeXt-101 32x4d', '1812.01187'),
_entry('gluon_resnext101_64x4d', 'ResNeXt-101 64x4d', '1812.01187'),
_entry('gluon_senet154', 'SENet-154', '1812.01187'),
_entry('gluon_seresnext50_32x4d', 'SE-ResNeXt-50 32x4d', '1812.01187'),
_entry('gluon_seresnext101_32x4d', 'SE-ResNeXt-101 32x4d', '1812.01187'),
_entry('gluon_seresnext101_64x4d', 'SE-ResNeXt-101 64x4d', '1812.01187'),
_entry('gluon_xception65', 'Modified Aligned Xception', '1802.02611', batch_size=BATCH_SIZE//2),
_entry('ig_resnext101_32x8d', 'ResNeXt-101 32x8d', '1805.00932'),
_entry('ig_resnext101_32x16d', 'ResNeXt-101 32x16d', '1805.00932'),
_entry('ig_resnext101_32x32d', 'ResNeXt-101 32x32d', '1805.00932', batch_size=BATCH_SIZE//2),
_entry('ig_resnext101_32x48d', 'ResNeXt-101 32x48d', '1805.00932', batch_size=BATCH_SIZE//4),
_entry('inception_resnet_v2', 'Inception ResNet V2', '1602.07261'),
#_entry('inception_v3', paper_model_name='Inception V3', ), # same weights as torchvision
_entry('inception_v4', 'Inception V4', '1602.07261'),
_entry('gluon_resnet34_v1b', 'ResNet-34', '1812.01187', model_desc='Ported from GluonCV Model Zoo'),
_entry('gluon_resnet50_v1b', 'ResNet-50', '1812.01187', model_desc='Ported from GluonCV Model Zoo'),
_entry('gluon_resnet50_v1c', 'ResNet-50-C', '1812.01187', model_desc='Ported from GluonCV Model Zoo'),
_entry('gluon_resnet50_v1d', 'ResNet-50-D', '1812.01187', model_desc='Ported from GluonCV Model Zoo'),
_entry('gluon_resnet50_v1s', 'ResNet-50-S', '1812.01187', model_desc='Ported from GluonCV Model Zoo'),
_entry('gluon_resnet101_v1b', 'ResNet-101', '1812.01187', model_desc='Ported from GluonCV Model Zoo'),
_entry('gluon_resnet101_v1c', 'ResNet-101-C', '1812.01187', model_desc='Ported from GluonCV Model Zoo'),
_entry('gluon_resnet101_v1d', 'ResNet-101-D', '1812.01187', model_desc='Ported from GluonCV Model Zoo'),
_entry('gluon_resnet101_v1s', 'ResNet-101-S', '1812.01187', model_desc='Ported from GluonCV Model Zoo'),
_entry('gluon_resnet152_v1b', 'ResNet-152', '1812.01187', model_desc='Ported from GluonCV Model Zoo'),
_entry('gluon_resnet152_v1c', 'ResNet-152-C', '1812.01187', model_desc='Ported from GluonCV Model Zoo'),
_entry('gluon_resnet152_v1d', 'ResNet-152-D', '1812.01187', model_desc='Ported from GluonCV Model Zoo'),
_entry('gluon_resnet152_v1s', 'ResNet-152-S', '1812.01187', model_desc='Ported from GluonCV Model Zoo'),
_entry('gluon_resnext50_32x4d', 'ResNeXt-50 32x4d', '1812.01187', model_desc='Ported from GluonCV Model Zoo'),
_entry('gluon_resnext101_32x4d', 'ResNeXt-101 32x4d', '1812.01187', model_desc='Ported from GluonCV Model Zoo'),
_entry('gluon_resnext101_64x4d', 'ResNeXt-101 64x4d', '1812.01187', model_desc='Ported from GluonCV Model Zoo'),
_entry('gluon_senet154', 'SENet-154', '1812.01187', model_desc='Ported from GluonCV Model Zoo'),
_entry('gluon_seresnext50_32x4d', 'SE-ResNeXt-50 32x4d', '1812.01187', model_desc='Ported from GluonCV Model Zoo'),
_entry('gluon_seresnext101_32x4d', 'SE-ResNeXt-101 32x4d', '1812.01187', model_desc='Ported from GluonCV Model Zoo'),
_entry('gluon_seresnext101_64x4d', 'SE-ResNeXt-101 64x4d', '1812.01187', model_desc='Ported from GluonCV Model Zoo'),
_entry('gluon_xception65', 'Modified Aligned Xception', '1802.02611', batch_size=BATCH_SIZE//2,
model_desc='Ported from GluonCV Model Zoo'),
_entry('mixnet_xl', 'MixNet-XL', '1907.09595', model_desc="My own scaling beyond paper's MixNet Large"),
_entry('mixnet_l', 'MixNet-L', '1907.09595'),
_entry('mixnet_m', 'MixNet-M', '1907.09595'),
_entry('mixnet_s', 'MixNet-S', '1907.09595'),
_entry('mnasnet_100', 'MnasNet-B1', '1807.11626'),
_entry('mobilenetv3_100', 'MobileNet V3(1.0)', '1905.02244',
model_desc='Trained from scratch in PyTorch with RMSProp, exponential LR decay, and hyper-params matching'
' paper as closely as possible.'),
_entry('nasnetalarge', 'NASNet-A Large', '1707.07012', batch_size=BATCH_SIZE//4),
_entry('pnasnet5large', 'PNASNet-5', '1712.00559', batch_size=BATCH_SIZE//4),
model_desc='Trained in PyTorch with RMSProp, exponential LR decay, and hyper-params matching '
'paper as closely as possible.'),
_entry('resnet18', 'ResNet-18', '1812.01187'),
_entry('resnet26', 'ResNet-26', '1812.01187'),
_entry('resnet26d', 'ResNet-26-D', '1812.01187'),
_entry('resnet26', 'ResNet-26', '1812.01187', model_desc='Block cfg of ResNet-34 w/ Bottleneck'),
_entry('resnet26d', 'ResNet-26-D', '1812.01187',
model_desc='Block cfg of ResNet-34 w/ Bottleneck, deep stem, and avg-pool in downsample layers.'),
_entry('resnet34', 'ResNet-34', '1812.01187'),
_entry('resnet50', 'ResNet-50', '1812.01187'),
#_entry('resnet101', , ), # same weights as torchvision
#_entry('resnet152', , ), # same weights as torchvision
_entry('resnext50_32x4d', 'ResNeXt-50 32x4d', '1812.01187'),
_entry('resnext50d_32x4d', 'ResNeXt-50-D 32x4d', '1812.01187',
model_desc="""'D' variant (3x3 deep stem w/ avg-pool downscale)
Trained with:
* SGD w/ cosine LR decay
* Random-erasing (gaussian per-pixel noise)
* Label-smoothing
"""),
#_entry('resnext101_32x8d', ), # same weights as torchvision
model_desc="'D' variant (3x3 deep stem w/ avg-pool downscale). Trained with "
"SGD w/ cosine LR decay, random-erasing (gaussian per-pixel noise) and label-smoothing"),
_entry('semnasnet_100', 'MnasNet-A1', '1807.11626'),
_entry('senet154', 'SENet-154', '1709.01507'),
_entry('seresnet18', 'SE-ResNet-18', '1709.01507'),
_entry('seresnet34', 'SE-ResNet-34', '1709.01507'),
_entry('seresnet50', 'SE-ResNet-50', '1709.01507'),
_entry('seresnet101', 'SE-ResNet-101', '1709.01507'),
_entry('seresnet152', 'SE-ResNet-152', '1709.01507'),
_entry('seresnext26_32x4d', 'SE-ResNeXt-26 32x4d', '1709.01507'),
_entry('seresnext50_32x4d', 'SE-ResNeXt-50 32x4d', '1709.01507'),
_entry('seresnext101_32x4d', 'SE-ResNeXt-101 32x4d', '1709.01507'),
_entry('spnasnet_100', 'Single-Path NAS', '1904.02877'),
_entry('seresnext26_32x4d', 'SE-ResNeXt-26 32x4d', '1709.01507',
model_desc='Block cfg of SE-ResNeXt-34 w/ Bottleneck, deep stem, and avg-pool in downsample layers.'),
_entry('spnasnet_100', 'Single-Path NAS', '1904.02877',
model_desc='Trained in PyTorch with SGD, cosine LR decay'),
_entry('tf_efficientnet_b0', 'EfficientNet-B0 (AutoAugment)', '1905.11946',
model_desc='Ported from official Google AI Tensorflow weights'),
_entry('tf_efficientnet_b1', 'EfficientNet-B1 (AutoAugment)', '1905.11946',
@ -135,18 +119,76 @@ model_list = [
model_desc='Ported from official Google AI Tensorflow weights'),
_entry('tf_efficientnet_el', 'EfficientNet-EdgeTPU-L', '1905.11946', batch_size=BATCH_SIZE//2,
model_desc='Ported from official Google AI Tensorflow weights'),
_entry('tf_inception_v3', 'Inception V3', '1512.00567'),
_entry('tf_mixnet_l', 'MixNet-L', '1907.09595'),
_entry('tf_mixnet_m', 'MixNet-M', '1907.09595'),
_entry('tf_mixnet_s', 'MixNet-S', '1907.09595'),
#_entry('tv_resnet34', , ), # same weights as torchvision
#_entry('tv_resnet50', , ), # same weights as torchvision
#_entry('tv_resnext50_32x4d', , ), # same weights as torchvision
#_entry('wide_resnet50_2' , ), # same weights as torchvision
#_entry('wide_resnet101_2', , ), # same weights as torchvision
_entry('tf_inception_v3', 'Inception V3', '1512.00567', model_desc='Ported from official Tensorflow weights'),
_entry('tf_mixnet_l', 'MixNet-L', '1907.09595', model_desc='Ported from official Google AI Tensorflow weights'),
_entry('tf_mixnet_m', 'MixNet-M', '1907.09595', model_desc='Ported from official Google AI Tensorflow weights'),
_entry('tf_mixnet_s', 'MixNet-S', '1907.09595', model_desc='Ported from official Google AI Tensorflow weights'),
## Cadene ported weights (to remove if Cadene adds sotabench)
_entry('inception_resnet_v2', 'Inception ResNet V2', '1602.07261'),
_entry('inception_v4', 'Inception V4', '1602.07261'),
_entry('nasnetalarge', 'NASNet-A Large', '1707.07012', batch_size=BATCH_SIZE // 4),
_entry('pnasnet5large', 'PNASNet-5', '1712.00559', batch_size=BATCH_SIZE // 4),
_entry('seresnet50', 'SE-ResNet-50', '1709.01507'),
_entry('seresnet101', 'SE-ResNet-101', '1709.01507'),
_entry('seresnet152', 'SE-ResNet-152', '1709.01507'),
_entry('seresnext50_32x4d', 'SE-ResNeXt-50 32x4d', '1709.01507'),
_entry('seresnext101_32x4d', 'SE-ResNeXt-101 32x4d', '1709.01507'),
_entry('senet154', 'SENet-154', '1709.01507'),
_entry('xception', 'Xception', '1610.02357'),
]
## Torchvision weights
# _entry('densenet121'),
# _entry('densenet161'),
# _entry('densenet169'),
# _entry('densenet201'),
# _entry('inception_v3', paper_model_name='Inception V3', ),
# _entry('tv_resnet34', , ),
# _entry('tv_resnet50', , ),
# _entry('resnet101', , ),
# _entry('resnet152', , ),
# _entry('tv_resnext50_32x4d', , ),
# _entry('resnext101_32x8d', ),
# _entry('wide_resnet50_2' , ),
# _entry('wide_resnet101_2', , ),
## Facebook WSL weights
_entry('ig_resnext101_32x8d', 'ResNeXt-101 32x8d', '1805.00932'),
_entry('ig_resnext101_32x16d', 'ResNeXt-101 32x16d', '1805.00932'),
_entry('ig_resnext101_32x32d', 'ResNeXt-101 32x32d', '1805.00932', batch_size=BATCH_SIZE // 2),
_entry('ig_resnext101_32x48d', 'ResNeXt-101 32x48d', '1805.00932', batch_size=BATCH_SIZE // 4),
_entry('ig_resnext101_32x8d (288x288 Mean-Max Pooling)', 'ResNeXt-101 32x8d', '1805.00932',
ttp=True, args=dict(img_size=288)),
_entry('ig_resnext101_32x16d (288x288 Mean-Max Pooling)', 'ResNeXt-101 32x16d', '1805.00932',
ttp=True, args=dict(img_size=288), batch_size=BATCH_SIZE // 2),
_entry('ig_resnext101_32x32d (288x288 Mean-Max Pooling)', 'ResNeXt-101 32x32d', '1805.00932',
ttp=True, args=dict(img_size=288), batch_size=BATCH_SIZE // 4),
_entry('ig_resnext101_32x48d (288x288 Mean-Max Pooling)', 'ResNeXt-101 32x48d', '1805.00932',
ttp=True, args=dict(img_size=288), batch_size=BATCH_SIZE // 8),
## DLA official impl weights (to remove if sotabench added to source)
_entry('dla34', 'DLA-34', '1707.06484'),
_entry('dla46_c', 'DLA-46-C', '1707.06484'),
_entry('dla46x_c', 'DLA-X-46-C', '1707.06484'),
_entry('dla60x_c', 'DLA-X-60-C', '1707.06484'),
_entry('dla60', 'DLA-60', '1707.06484'),
_entry('dla60x', 'DLA-X-60', '1707.06484'),
_entry('dla102', 'DLA-102', '1707.06484'),
_entry('dla102x', 'DLA-X-102', '1707.06484'),
_entry('dla102x2', 'DLA-X-102 64', '1707.06484'),
_entry('dla169', 'DLA-169', '1707.06484'),
## Res2Net official impl weights (to remove if sotabench added to source)
_entry('res2net50_26w_4s', 'Res2Net-50 26x4s', '1904.01169'),
_entry('res2net50_14w_8s', 'Res2Net-50 14x8s', '1904.01169'),
_entry('res2net50_26w_6s', 'Res2Net-50 26x6s', '1904.01169'),
_entry('res2net50_26w_8s', 'Res2Net-50 26x8s', '1904.01169'),
_entry('res2net50_48w_2s', 'Res2Net-50 48x2s', '1904.01169'),
_entry('res2net101_26w_4s', 'Res2NeXt-101 26x4s', '1904.01169'),
_entry('res2next50', 'Res2NeXt-50', '1904.01169'),
_entry('dla60_res2net', 'Res2Net-DLA-60', '1904.01169'),
_entry('dla60_res2next', 'Res2NeXt-DLA-60', '1904.01169'),
]
for m in model_list:
model_name = m['model']

@ -11,6 +11,8 @@ from .gen_efficientnet import *
from .inception_v3 import *
from .gluon_resnet import *
from .gluon_xception import *
from .res2net import *
from .dla import *
from .registry import *
from .factory import create_model

@ -0,0 +1,471 @@
""" Deep Layer Aggregation and DLA w/ Res2Net
DLA original adapted from Official Pytorch impl at:
DLA Paper: `Deep Layer Aggregation` - https://arxiv.org/abs/1707.06484
Res2Net additions from: https://github.com/gasvn/Res2Net/
Res2Net Paper: `Res2Net: A New Multi-scale Backbone Architecture` - https://arxiv.org/abs/1904.01169
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .registry import register_model
from .helpers import load_pretrained
from .adaptive_avgmax_pool import SelectAdaptivePool2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
__all__ = ['DLA']
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bilinear',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'base_layer.0', 'classifier': 'fc',
**kwargs
}
default_cfgs = {
'dla34': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla34-ba72cf86.pth'),
'dla46_c': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla46_c-2bfd52c3.pth'),
'dla46x_c': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla46x_c-d761bae7.pth'),
'dla60x_c': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla60x_c-b870c45c.pth'),
'dla60': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla60-24839fc4.pth'),
'dla60x': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla60x-d15cacda.pth'),
'dla102': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla102-d94d9790.pth'),
'dla102x': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla102x-ad62be81.pth'),
'dla102x2': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla102x2-262837b6.pth'),
'dla169': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla169-0914e092.pth'),
'dla60_res2net': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net_dla60_4s-d88db7f9.pth'),
'dla60_res2next': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2next_dla60_4s-d327927b.pth'),
}
class DlaBasic(nn.Module):
"""DLA Basic"""
def __init__(self, inplanes, planes, stride=1, dilation=1, **_):
super(DlaBasic, self).__init__()
self.conv1 = nn.Conv2d(
inplanes, planes, kernel_size=3, stride=stride, padding=dilation, bias=False, dilation=dilation)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(
planes, planes, kernel_size=3, stride=1, padding=dilation, bias=False, dilation=dilation)
self.bn2 = nn.BatchNorm2d(planes)
self.stride = stride
def forward(self, x, residual=None):
if residual is None:
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out += residual
out = self.relu(out)
return out
class DlaBottleneck(nn.Module):
"""DLA/DLA-X Bottleneck"""
expansion = 2
def __init__(self, inplanes, outplanes, stride=1, dilation=1, cardinality=1, base_width=64):
super(DlaBottleneck, self).__init__()
self.stride = stride
mid_planes = int(math.floor(outplanes * (base_width / 64)) * cardinality)
mid_planes = mid_planes // self.expansion
self.conv1 = nn.Conv2d(inplanes, mid_planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(mid_planes)
self.conv2 = nn.Conv2d(
mid_planes, mid_planes, kernel_size=3, stride=stride, padding=dilation,
bias=False, dilation=dilation, groups=cardinality)
self.bn2 = nn.BatchNorm2d(mid_planes)
self.conv3 = nn.Conv2d(mid_planes, outplanes, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(outplanes)
self.relu = nn.ReLU(inplace=True)
def forward(self, x, residual=None):
if residual is None:
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
out += residual
out = self.relu(out)
return out
class DlaBottle2neck(nn.Module):
""" Res2Net/Res2NeXT DLA Bottleneck
Adapted from https://github.com/gasvn/Res2Net/blob/master/dla.py
"""
expansion = 2
def __init__(self, inplanes, outplanes, stride=1, dilation=1, scale=4, cardinality=8, base_width=4):
super(DlaBottle2neck, self).__init__()
self.is_first = stride > 1
self.scale = scale
mid_planes = int(math.floor(outplanes * (base_width / 64)) * cardinality)
mid_planes = mid_planes // self.expansion
self.width = mid_planes
self.conv1 = nn.Conv2d(inplanes, mid_planes * scale, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(mid_planes * scale)
num_scale_convs = max(1, scale - 1)
convs = []
bns = []
for _ in range(num_scale_convs):
convs.append(nn.Conv2d(
mid_planes, mid_planes, kernel_size=3, stride=stride,
padding=dilation, dilation=dilation, groups=cardinality, bias=False))
bns.append(nn.BatchNorm2d(mid_planes))
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList(bns)
if self.is_first:
self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)
self.conv3 = nn.Conv2d(mid_planes * scale, outplanes, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(outplanes)
self.relu = nn.ReLU(inplace=True)
def forward(self, x, residual=None):
if residual is None:
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
spx = torch.split(out, self.width, 1)
spo = []
for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
sp = spx[i] if i == 0 or self.is_first else sp + spx[i]
sp = conv(sp)
sp = bn(sp)
sp = self.relu(sp)
spo.append(sp)
if self.scale > 1 :
spo.append(self.pool(spx[-1]) if self.is_first else spx[-1])
out = torch.cat(spo, 1)
out = self.conv3(out)
out = self.bn3(out)
out += residual
out = self.relu(out)
return out
class DlaRoot(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, residual):
super(DlaRoot, self).__init__()
self.conv = nn.Conv2d(
in_channels, out_channels, 1, stride=1, bias=False, padding=(kernel_size - 1) // 2)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
self.residual = residual
def forward(self, *x):
children = x
x = self.conv(torch.cat(x, 1))
x = self.bn(x)
if self.residual:
x += children[0]
x = self.relu(x)
return x
class DlaTree(nn.Module):
def __init__(self, levels, block, in_channels, out_channels, stride=1,
dilation=1, cardinality=1, base_width=64,
level_root=False, root_dim=0, root_kernel_size=1, root_residual=False):
super(DlaTree, self).__init__()
if root_dim == 0:
root_dim = 2 * out_channels
if level_root:
root_dim += in_channels
cargs = dict(dilation=dilation, cardinality=cardinality, base_width=base_width)
if levels == 1:
self.tree1 = block(in_channels, out_channels, stride, **cargs)
self.tree2 = block(out_channels, out_channels, 1, **cargs)
else:
cargs.update(dict(root_kernel_size=root_kernel_size, root_residual=root_residual))
self.tree1 = DlaTree(
levels - 1, block, in_channels, out_channels, stride, root_dim=0, **cargs)
self.tree2 = DlaTree(
levels - 1, block, out_channels, out_channels, root_dim=root_dim + out_channels, **cargs)
if levels == 1:
self.root = DlaRoot(root_dim, out_channels, root_kernel_size, root_residual)
self.level_root = level_root
self.root_dim = root_dim
self.downsample = nn.MaxPool2d(stride, stride=stride) if stride > 1 else None
self.project = None
if in_channels != out_channels:
self.project = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
nn.BatchNorm2d(out_channels)
)
self.levels = levels
def forward(self, x, residual=None, children=None):
children = [] if children is None else children
bottom = self.downsample(x) if self.downsample else x
residual = self.project(bottom) if self.project else bottom
if self.level_root:
children.append(bottom)
x1 = self.tree1(x, residual)
if self.levels == 1:
x2 = self.tree2(x1)
x = self.root(x2, x1, *children)
else:
children.append(x1)
x = self.tree2(x1, children=children)
return x
class DLA(nn.Module):
def __init__(self, levels, channels, num_classes=1000, in_chans=3, cardinality=1, base_width=64,
block=DlaBottle2neck, residual_root=False, linear_root=False,
drop_rate=0.0, global_pool='avg'):
super(DLA, self).__init__()
self.channels = channels
self.num_classes = num_classes
self.cardinality = cardinality
self.base_width = base_width
self.drop_rate = drop_rate
self.base_layer = nn.Sequential(
nn.Conv2d(in_chans, channels[0], kernel_size=7, stride=1, padding=3, bias=False),
nn.BatchNorm2d(channels[0]),
nn.ReLU(inplace=True))
self.level0 = self._make_conv_level(channels[0], channels[0], levels[0])
self.level1 = self._make_conv_level(channels[0], channels[1], levels[1], stride=2)
cargs = dict(cardinality=cardinality, base_width=base_width, root_residual=residual_root)
self.level2 = DlaTree(levels[2], block, channels[1], channels[2], 2, level_root=False, **cargs)
self.level3 = DlaTree(levels[3], block, channels[2], channels[3], 2, level_root=True, **cargs)
self.level4 = DlaTree(levels[4], block, channels[3], channels[4], 2, level_root=True, **cargs)
self.level5 = DlaTree(levels[5], block, channels[4], channels[5], 2, level_root=True, **cargs)
self.num_features = channels[-1]
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.fc = nn.Conv2d(self.num_features * self.global_pool.feat_mult(), num_classes,
kernel_size=1, stride=1, padding=0, bias=True)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1):
modules = []
for i in range(convs):
modules.extend([
nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride if i == 0 else 1,
padding=dilation, bias=False, dilation=dilation),
nn.BatchNorm2d(planes),
nn.ReLU(inplace=True)])
inplanes = planes
return nn.Sequential(*modules)
def get_classifier(self):
return self.fc
def reset_classifier(self, num_classes, global_pool='avg'):
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.num_classes = num_classes
del self.fc
if num_classes:
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
else:
self.fc = None
def forward_features(self, x, pool=True):
x = self.base_layer(x)
x = self.level0(x)
x = self.level1(x)
x = self.level2(x)
x = self.level3(x)
x = self.level4(x)
x = self.level5(x)
if pool:
x = self.global_pool(x)
return x
def forward(self, x):
x = self.forward_features(x)
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.fc(x)
x = x.flatten(1)
return x
@register_model
def dla60_res2net(pretrained=None, num_classes=1000, in_chans=3, **kwargs):
default_cfg = default_cfgs['dla60_res2net']
model = DLA(levels=(1, 1, 1, 2, 3, 1), channels=(16, 32, 128, 256, 512, 1024),
block=DlaBottle2neck, cardinality=1, base_width=28,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def dla60_res2next(pretrained=None, num_classes=1000, in_chans=3, **kwargs):
default_cfg = default_cfgs['dla60_res2next']
model = DLA(levels=(1, 1, 1, 2, 3, 1), channels=(16, 32, 128, 256, 512, 1024),
block=DlaBottle2neck, cardinality=8, base_width=4,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def dla34(pretrained=None, num_classes=1000, in_chans=3, **kwargs): # DLA-34
default_cfg = default_cfgs['dla34']
model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512], block=DlaBasic, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def dla46_c(pretrained=None, num_classes=1000, in_chans=3, **kwargs): # DLA-46-C
default_cfg = default_cfgs['dla46_c']
model = DLA(levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 64, 128, 256],
block=DlaBottleneck, num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def dla46x_c(pretrained=None, num_classes=1000, in_chans=3, **kwargs): # DLA-X-46-C
default_cfg = default_cfgs['dla46x_c']
model = DLA(levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 64, 128, 256],
block=DlaBottleneck, cardinality=32, base_width=4,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def dla60x_c(pretrained=None, num_classes=1000, in_chans=3, **kwargs): # DLA-X-60-C
default_cfg = default_cfgs['dla60x_c']
model = DLA([1, 1, 1, 2, 3, 1], [16, 32, 64, 64, 128, 256],
block=DlaBottleneck, cardinality=32, base_width=4,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def dla60(pretrained=None, num_classes=1000, in_chans=3, **kwargs): # DLA-60
default_cfg = default_cfgs['dla60']
model = DLA([1, 1, 1, 2, 3, 1], [16, 32, 128, 256, 512, 1024],
block=DlaBottleneck, num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def dla60x(pretrained=None, num_classes=1000, in_chans=3, **kwargs): # DLA-X-60
default_cfg = default_cfgs['dla60x']
model = DLA([1, 1, 1, 2, 3, 1], [16, 32, 128, 256, 512, 1024],
block=DlaBottleneck, cardinality=32, base_width=4,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def dla102(pretrained=None, num_classes=1000, in_chans=3, **kwargs): # DLA-102
default_cfg = default_cfgs['dla102']
model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024],
block=DlaBottleneck, residual_root=True,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def dla102x(pretrained=None, num_classes=1000, in_chans=3, **kwargs): # DLA-X-102
default_cfg = default_cfgs['dla102x']
model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024],
block=DlaBottleneck, cardinality=32, base_width=4, residual_root=True,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def dla102x2(pretrained=None, num_classes=1000, in_chans=3, **kwargs): # DLA-X-102 64
default_cfg = default_cfgs['dla102x2']
model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024],
block=DlaBottleneck, cardinality=64, base_width=4, residual_root=True,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def dla169(pretrained=None, num_classes=1000, in_chans=3, **kwargs): # DLA-169
default_cfg = default_cfgs['dla169']
model = DLA([1, 1, 2, 3, 5, 1], [16, 32, 128, 256, 512, 1024],
block=DlaBottleneck, residual_root=True,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model

@ -0,0 +1,242 @@
""" Res2Net and Res2NeXt
Adapted from Official Pytorch impl at: https://github.com/gasvn/Res2Net/
Paper: `Res2Net: A New Multi-scale Backbone Architecture` - https://arxiv.org/abs/1904.01169
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .resnet import ResNet, SEModule
from .registry import register_model
from .helpers import load_pretrained
from .adaptive_avgmax_pool import SelectAdaptivePool2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
__all__ = []
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bilinear',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'conv1', 'classifier': 'fc',
**kwargs
}
default_cfgs = {
'res2net50_26w_4s': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_4s-06e79181.pth'),
'res2net50_48w_2s': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_48w_2s-afed724a.pth'),
'res2net50_14w_8s': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_14w_8s-6527dddc.pth'),
'res2net50_26w_6s': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_6s-19041792.pth'),
'res2net50_26w_8s': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_8s-2c7c9f12.pth'),
'res2net101_26w_4s': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net101_26w_4s-02a759a1.pth'),
'res2next50': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2next50_4s-6ef7e7bf.pth'),
}
class Bottle2neck(nn.Module):
""" Res2Net/Res2NeXT Bottleneck
Adapted from https://github.com/gasvn/Res2Net/blob/master/res2net.py
"""
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None,
cardinality=1, base_width=26, scale=4, use_se=False,
norm_layer=None, dilation=1, previous_dilation=1, **_):
super(Bottle2neck, self).__init__()
assert dilation == 1 and previous_dilation == 1 # FIXME support dilation
self.scale = scale
self.is_first = True if stride > 1 or downsample is not None else False
self.num_scales = max(1, scale - 1)
width = int(math.floor(planes * (base_width / 64.0))) * cardinality
outplanes = planes * self.expansion
self.width = width
self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False)
self.bn1 = norm_layer(width * scale)
convs = []
bns = []
for i in range(self.num_scales):
convs.append(nn.Conv2d(
width, width, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False))
bns.append(norm_layer(width))
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList(bns)
if self.is_first:
self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)
self.conv3 = nn.Conv2d(width * scale, outplanes, kernel_size=1, bias=False)
self.bn3 = norm_layer(outplanes)
self.se = SEModule(outplanes, planes // 4) if use_se else None
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
spx = torch.split(out, self.width, 1)
spo = []
for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
sp = spx[i] if i == 0 or self.is_first else sp + spx[i]
sp = conv(sp)
sp = bn(sp)
sp = self.relu(sp)
spo.append(sp)
if self.scale > 1 :
spo.append(self.pool(spx[-1]) if self.is_first else spx[-1])
out = torch.cat(spo, 1)
out = self.conv3(out)
out = self.bn3(out)
if self.se is not None:
out = self.se(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
@register_model
def res2net50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a Res2Net-50 model.
Res2Net-50 refers to the Res2Net-50_26w_4s.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
return res2net50_26w_4s(pretrained, num_classes, in_chans, **kwargs)
@register_model
def res2net50_26w_4s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a Res2Net-50_26w_4s model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
default_cfg = default_cfgs['res2net50_26w_4s']
res2net_block_args = dict(scale=4)
model = ResNet(Bottle2neck, [3, 4, 6, 3], base_width=26,
num_classes=num_classes, in_chans=in_chans, block_args=res2net_block_args, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def res2net101_26w_4s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a Res2Net-50_26w_4s model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
default_cfg = default_cfgs['res2net101_26w_4s']
res2net_block_args = dict(scale=4)
model = ResNet(Bottle2neck, [3, 4, 23, 3], base_width=26,
num_classes=num_classes, in_chans=in_chans, block_args=res2net_block_args, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def res2net50_26w_6s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a Res2Net-50_26w_4s model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
default_cfg = default_cfgs['res2net50_26w_6s']
res2net_block_args = dict(scale=6)
model = ResNet(Bottle2neck, [3, 4, 6, 3], base_width=26,
num_classes=num_classes, in_chans=in_chans, block_args=res2net_block_args, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def res2net50_26w_8s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a Res2Net-50_26w_4s model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
default_cfg = default_cfgs['res2net50_26w_8s']
res2net_block_args = dict(scale=8)
model = ResNet(Bottle2neck, [3, 4, 6, 3], base_width=26,
num_classes=num_classes, in_chans=in_chans, block_args=res2net_block_args, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def res2net50_48w_2s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a Res2Net-50_48w_2s model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
default_cfg = default_cfgs['res2net50_48w_2s']
res2net_block_args = dict(scale=2)
model = ResNet(Bottle2neck, [3, 4, 6, 3], base_width=48,
num_classes=num_classes, in_chans=in_chans, block_args=res2net_block_args, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def res2net50_14w_8s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a Res2Net-50_14w_8s model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
default_cfg = default_cfgs['res2net50_14w_8s']
res2net_block_args = dict(scale=8)
model = ResNet(Bottle2neck, [3, 4, 6, 3], base_width=14, num_classes=num_classes, in_chans=in_chans,
block_args=res2net_block_args, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def res2next50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Construct Res2NeXt-50 4s
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
default_cfg = default_cfgs['res2next50']
res2net_block_args = dict(scale=4)
model = ResNet(Bottle2neck, [3, 4, 6, 3], base_width=4, cardinality=8,
num_classes=1000, in_chans=in_chans, block_args=res2net_block_args, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model

@ -262,7 +262,8 @@ class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False,
cardinality=1, base_width=64, stem_width=64, deep_stem=False,
block_reduce_first=1, down_kernel_size=1, avg_down=False, dilated=False,
norm_layer=nn.BatchNorm2d, drop_rate=0.0, global_pool='avg', zero_init_last_bn=True):
norm_layer=nn.BatchNorm2d, drop_rate=0.0, global_pool='avg',
zero_init_last_bn=True, block_args=dict()):
self.num_classes = num_classes
self.inplanes = stem_width * 2 if deep_stem else 64
self.cardinality = cardinality
@ -290,7 +291,7 @@ class ResNet(nn.Module):
dilation_3 = 2 if self.dilated else 1
dilation_4 = 4 if self.dilated else 1
largs = dict(use_se=use_se, reduce_first=block_reduce_first, norm_layer=norm_layer,
avg_down=avg_down, down_kernel_size=down_kernel_size)
avg_down=avg_down, down_kernel_size=down_kernel_size, **block_args)
self.layer1 = self._make_layer(block, 64, layers[0], stride=1, **largs)
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, **largs)
self.layer3 = self._make_layer(block, 256, layers[2], stride=stride_3_4, dilation=dilation_3, **largs)
@ -312,7 +313,7 @@ class ResNet(nn.Module):
nn.init.constant_(m.bias, 0.)
def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1,
use_se=False, avg_down=False, down_kernel_size=1, norm_layer=nn.BatchNorm2d):
use_se=False, avg_down=False, down_kernel_size=1, norm_layer=nn.BatchNorm2d, **kwargs):
downsample = None
down_kernel_size = 1 if stride == 1 and dilation == 1 else down_kernel_size
if stride != 1 or self.inplanes != planes * block.expansion:
@ -330,16 +331,15 @@ class ResNet(nn.Module):
downsample = nn.Sequential(*downsample_layers)
first_dilation = 1 if dilation in (1, 2) else 2
layers = [block(
self.inplanes, planes, stride, downsample,
bargs = dict(
cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
use_se=use_se, dilation=first_dilation, previous_dilation=dilation, norm_layer=norm_layer)]
use_se=use_se, norm_layer=norm_layer, **kwargs)
layers = [block(
self.inplanes, planes, stride, downsample, dilation=first_dilation, previous_dilation=dilation, **bargs)]
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(
self.inplanes, planes,
cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
use_se=use_se, dilation=dilation, previous_dilation=dilation, norm_layer=norm_layer))
self.inplanes, planes, dilation=dilation, previous_dilation=dilation, **bargs))
return nn.Sequential(*layers)
@ -632,7 +632,8 @@ def ig_resnext101_32x8d(pretrained=True, num_classes=1000, in_chans=3, **kwargs)
in_chans (int): number of input planes (default: 3 for pretrained / color)
"""
default_cfg = default_cfgs['ig_resnext101_32x8d']
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=8,
num_classes=1000, in_chans=3, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
@ -651,7 +652,8 @@ def ig_resnext101_32x16d(pretrained=True, num_classes=1000, in_chans=3, **kwargs
in_chans (int): number of input planes (default: 3 for pretrained / color)
"""
default_cfg = default_cfgs['ig_resnext101_32x16d']
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=16, **kwargs)
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=16,
num_classes=1000, in_chans=3, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
@ -670,7 +672,8 @@ def ig_resnext101_32x32d(pretrained=True, num_classes=1000, in_chans=3, **kwargs
in_chans (int): number of input planes (default: 3 for pretrained / color)
"""
default_cfg = default_cfgs['ig_resnext101_32x32d']
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=32, **kwargs)
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=32,
num_classes=1000, in_chans=3, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
@ -689,7 +692,8 @@ def ig_resnext101_32x48d(pretrained=True, num_classes=1000, in_chans=3, **kwargs
in_chans (int): number of input planes (default: 3 for pretrained / color)
"""
default_cfg = default_cfgs['ig_resnext101_32x48d']
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=48, **kwargs)
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=48,
num_classes=1000, in_chans=3, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)

Loading…
Cancel
Save