Merge branch 'master' into densenet_update_and_more

pull/155/head
Ross Wightman 5 years ago
commit 7df83258c9

@ -0,0 +1,42 @@
name: Python tests
on:
push:
branches: [ master ]
pull_request:
branches: [ master ]
jobs:
test:
name: Run tests on ${{ matrix.os }} with Python ${{ matrix.python }}
strategy:
matrix:
os: [ubuntu-latest, macOS-latest]
python: ['3.8']
torch: ['1.5.0']
torchvision: ['0.6.0']
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python }}
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python }}
- name: Install testing dependencies
run: |
python -m pip install --upgrade pip
pip install pytest pytest-timeout
- name: Install torch on mac
if: startsWith(matrix.os, 'macOS')
run: pip install torch==${{ matrix.torch }} torchvision==${{ matrix.torchvision }}
- name: Install torch on ubuntu
if: startsWith(matrix.os, 'ubuntu')
run: pip install torch==${{ matrix.torch }}+cpu torchvision==${{ matrix.torchvision }}+cpu -f https://download.pytorch.org/whl/torch_stable.html
- name: Install requirements
run: |
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11
- name: Run tests
run: |
pytest -vv --durations=0 ./tests

@ -2,6 +2,9 @@
## What's New ## What's New
### May 12, 2020
* Add ResNeSt models (code adapted from https://github.com/zhanghang1989/ResNeSt, paper https://arxiv.org/abs/2004.08955))
### May 3, 2020 ### May 3, 2020
* Pruned EfficientNet B1, B2, and B3 (https://arxiv.org/abs/2002.08258) contributed by [Yonathan Aflalo](https://github.com/yoniaflalo) * Pruned EfficientNet B1, B2, and B3 (https://arxiv.org/abs/2002.08258) contributed by [Yonathan Aflalo](https://github.com/yoniaflalo)
@ -70,41 +73,6 @@
* Add RandAugment trained EfficientNet-B0 weight with 77.7 top-1. Trained by [Michael Klachko](https://github.com/michaelklachko) with this code and recent hparams (see Training section) * Add RandAugment trained EfficientNet-B0 weight with 77.7 top-1. Trained by [Michael Klachko](https://github.com/michaelklachko) with this code and recent hparams (see Training section)
* Add `avg_checkpoints.py` script for post training weight averaging and update all scripts with header docstrings and shebangs. * Add `avg_checkpoints.py` script for post training weight averaging and update all scripts with header docstrings and shebangs.
### Dec 30, 2019
* Merge [Dushyant Mehta's](https://github.com/mehtadushy) PR for SelecSLS (Selective Short and Long Range Skip Connections) networks. Good GPU memory consumption and throughput. Original: https://github.com/mehtadushy/SelecSLS-Pytorch
### Dec 28, 2019
* Add new model weights and training hparams (see Training Hparams section)
* `efficientnet_b3` - 81.5 top-1, 95.7 top-5 at default res/crop, 81.9, 95.8 at 320x320 1.0 crop-pct
* trained with RandAugment, ended up with an interesting but less than perfect result (see training section)
* `seresnext26d_32x4d`- 77.6 top-1, 93.6 top-5
* deep stem (32, 32, 64), avgpool downsample
* stem/dowsample from bag-of-tricks paper
* `seresnext26t_32x4d`- 78.0 top-1, 93.7 top-5
* deep tiered stem (24, 48, 64), avgpool downsample (a modified 'D' variant)
* stem sizing mods from Jeremy Howard and fastai devs discussing ResNet architecture experiments
### Dec 23, 2019
* Add RandAugment trained MixNet-XL weights with 80.48 top-1.
* `--dist-bn` argument added to train.py, will distribute BN stats between nodes after each train epoch, before eval
### Dec 4, 2019
* Added weights from the first training from scratch of an EfficientNet (B2) with my new RandAugment implementation. Much better than my previous B2 and very close to the official AdvProp ones (80.4 top-1, 95.08 top-5).
### Nov 29, 2019
* Brought EfficientNet and MobileNetV3 up to date with my https://github.com/rwightman/gen-efficientnet-pytorch code. Torchscript and ONNX export compat excluded.
* AdvProp weights added
* Official TF MobileNetv3 weights added
* EfficientNet and MobileNetV3 hook based 'feature extraction' classes added. Will serve as basis for using models as backbones in obj detection/segmentation tasks. Lots more to be done here...
* HRNet classification models and weights added from https://github.com/HRNet/HRNet-Image-Classification
* Consistency in global pooling, `reset_classifer`, and `forward_features` across models
* `forward_features` always returns unpooled feature maps now
* Reasonable chance I broke something... let me know
### Nov 22, 2019
* Add ImageNet training RandAugment implementation alongside AutoAugment. PyTorch Transform compatible format, using PIL. Currently training two EfficientNet models from scratch with promising results... will update.
* `drop-connect` cmd line arg finally added to `train.py`, no need to hack model fns. Works for efficientnet/mobilenetv3 based models, ignored otherwise.
## Introduction ## Introduction
For each competition, personal, or freelance project involving images + Convolution Neural Networks, I build on top of an evolving collection of code and models. This repo contains a (somewhat) cleaned up and paired down iteration of that code. Hopefully it'll be of use to others. For each competition, personal, or freelance project involving images + Convolution Neural Networks, I build on top of an evolving collection of code and models. This repo contains a (somewhat) cleaned up and paired down iteration of that code. Hopefully it'll be of use to others.
@ -130,6 +98,7 @@ Included models:
* Instagram trained / ImageNet tuned ResNeXt101-32x8d to 32x48d from from [facebookresearch](https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/) * Instagram trained / ImageNet tuned ResNeXt101-32x8d to 32x48d from from [facebookresearch](https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/)
* Res2Net (https://github.com/gasvn/Res2Net, https://arxiv.org/abs/1904.01169) * Res2Net (https://github.com/gasvn/Res2Net, https://arxiv.org/abs/1904.01169)
* Selective Kernel (SK) Nets (https://arxiv.org/abs/1903.06586) * Selective Kernel (SK) Nets (https://arxiv.org/abs/1903.06586)
* ResNeSt (code adapted from https://github.com/zhanghang1989/ResNeSt, paper https://arxiv.org/abs/2004.08955)
* DLA * DLA
* Original (https://github.com/ucbdrive/dla, https://arxiv.org/abs/1707.06484) * Original (https://github.com/ucbdrive/dla, https://arxiv.org/abs/1707.06484)
* Res2Net (https://github.com/gasvn/Res2Net, https://arxiv.org/abs/1904.01169) * Res2Net (https://github.com/gasvn/Res2Net, https://arxiv.org/abs/1904.01169)

@ -0,0 +1,19 @@
import pytest
import torch
from timm import list_models, create_model
@pytest.mark.timeout(300)
@pytest.mark.parametrize('model_name', list_models(exclude_filters='*efficientnet_l2*'))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward(model_name, batch_size):
"""Run a single forward pass with each model"""
model = create_model(model_name, pretrained=False)
model.eval()
inputs = torch.randn((batch_size, *model.default_cfg['input_size']))
outputs = model(inputs)
assert outputs.shape[0] == batch_size
assert not torch.isnan(outputs).any(), 'Output included NaNs'

@ -18,6 +18,7 @@ from .dla import *
from .hrnet import * from .hrnet import *
from .sknet import * from .sknet import *
from .tresnet import * from .tresnet import *
from .resnest import *
from .registry import * from .registry import *
from .factory import create_model from .factory import create_model

@ -1,121 +1,561 @@
from torchvision.models import Inception3 import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import load_pretrained from .helpers import load_pretrained
from .registry import register_model from .registry import register_model
from .layers import trunc_normal_, SelectAdaptivePool2d
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
'crop_pct': 0.875, 'interpolation': 'bicubic',
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
'first_conv': 'conv1', 'classifier': 'fc',
**kwargs
}
__all__ = []
default_cfgs = { default_cfgs = {
# original PyTorch weights, ported from Tensorflow but modified # original PyTorch weights, ported from Tensorflow but modified
'inception_v3': { 'inception_v3': _cfg(
'url': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth', url='https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
'input_size': (3, 299, 299), has_aux=True), # checkpoint has aux logit layer weights
'crop_pct': 0.875,
'interpolation': 'bicubic',
'mean': IMAGENET_INCEPTION_MEAN, # also works well enough with resnet defaults
'std': IMAGENET_INCEPTION_STD, # also works well enough with resnet defaults
'num_classes': 1000,
'first_conv': 'conv0',
'classifier': 'fc'
},
# my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz) # my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)
'tf_inception_v3': { 'tf_inception_v3': _cfg(
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_inception_v3-e0069de4.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_inception_v3-e0069de4.pth',
'input_size': (3, 299, 299), num_classes=1001, has_aux=False),
'crop_pct': 0.875,
'interpolation': 'bicubic',
'mean': IMAGENET_INCEPTION_MEAN,
'std': IMAGENET_INCEPTION_STD,
'num_classes': 1001,
'first_conv': 'conv0',
'classifier': 'fc'
},
# my port of Tensorflow adversarially trained Inception V3 from # my port of Tensorflow adversarially trained Inception V3 from
# http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz # http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz
'adv_inception_v3': { 'adv_inception_v3': _cfg(
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/adv_inception_v3-9e27bd63.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/adv_inception_v3-9e27bd63.pth',
'input_size': (3, 299, 299), num_classes=1001, has_aux=False),
'crop_pct': 0.875,
'interpolation': 'bicubic',
'mean': IMAGENET_INCEPTION_MEAN,
'std': IMAGENET_INCEPTION_STD,
'num_classes': 1001,
'first_conv': 'conv0',
'classifier': 'fc'
},
# from gluon pretrained models, best performing in terms of accuracy/loss metrics # from gluon pretrained models, best performing in terms of accuracy/loss metrics
# https://gluon-cv.mxnet.io/model_zoo/classification.html # https://gluon-cv.mxnet.io/model_zoo/classification.html
'gluon_inception_v3': { 'gluon_inception_v3': _cfg(
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_inception_v3-9f746940.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_inception_v3-9f746940.pth',
'input_size': (3, 299, 299), mean=IMAGENET_DEFAULT_MEAN, # also works well with inception defaults
'crop_pct': 0.875, std=IMAGENET_DEFAULT_STD, # also works well with inception defaults
'interpolation': 'bicubic', has_aux=False,
'mean': IMAGENET_DEFAULT_MEAN, # also works well with inception defaults )
'std': IMAGENET_DEFAULT_STD, # also works well with inception defaults
'num_classes': 1000,
'first_conv': 'conv0',
'classifier': 'fc'
}
} }
def _assert_default_kwargs(kwargs): class InceptionV3Aux(nn.Module):
# for imported models (ie torchvision) without capability to change these params, """InceptionV3 with AuxLogits
# make sure they aren't being set to non-defaults """
assert kwargs.pop('global_pool', 'avg') == 'avg'
assert kwargs.pop('drop_rate', 0.) == 0. def __init__(self, inception_blocks=None, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg'):
super(InceptionV3Aux, self).__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
if inception_blocks is None:
inception_blocks = [
BasicConv2d, InceptionA, InceptionB, InceptionC,
InceptionD, InceptionE, InceptionAux
]
assert len(inception_blocks) == 7
conv_block = inception_blocks[0]
inception_a = inception_blocks[1]
inception_b = inception_blocks[2]
inception_c = inception_blocks[3]
inception_d = inception_blocks[4]
inception_e = inception_blocks[5]
inception_aux = inception_blocks[6]
self.Conv2d_1a_3x3 = conv_block(in_chans, 32, kernel_size=3, stride=2)
self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
self.Mixed_5b = inception_a(192, pool_features=32)
self.Mixed_5c = inception_a(256, pool_features=64)
self.Mixed_5d = inception_a(288, pool_features=64)
self.Mixed_6a = inception_b(288)
self.Mixed_6b = inception_c(768, channels_7x7=128)
self.Mixed_6c = inception_c(768, channels_7x7=160)
self.Mixed_6d = inception_c(768, channels_7x7=160)
self.Mixed_6e = inception_c(768, channels_7x7=192)
self.AuxLogits = inception_aux(768, num_classes)
self.Mixed_7a = inception_d(768)
self.Mixed_7b = inception_e(1280)
self.Mixed_7c = inception_e(2048)
self.num_features = 2048
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
stddev = m.stddev if hasattr(m, 'stddev') else 0.1
trunc_normal_(m.weight, std=stddev)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward_features(self, x):
# N x 3 x 299 x 299
x = self.Conv2d_1a_3x3(x)
# N x 32 x 149 x 149
x = self.Conv2d_2a_3x3(x)
# N x 32 x 147 x 147
x = self.Conv2d_2b_3x3(x)
# N x 64 x 147 x 147
x = F.max_pool2d(x, kernel_size=3, stride=2)
# N x 64 x 73 x 73
x = self.Conv2d_3b_1x1(x)
# N x 80 x 73 x 73
x = self.Conv2d_4a_3x3(x)
# N x 192 x 71 x 71
x = F.max_pool2d(x, kernel_size=3, stride=2)
# N x 192 x 35 x 35
x = self.Mixed_5b(x)
# N x 256 x 35 x 35
x = self.Mixed_5c(x)
# N x 288 x 35 x 35
x = self.Mixed_5d(x)
# N x 288 x 35 x 35
x = self.Mixed_6a(x)
# N x 768 x 17 x 17
x = self.Mixed_6b(x)
# N x 768 x 17 x 17
x = self.Mixed_6c(x)
# N x 768 x 17 x 17
x = self.Mixed_6d(x)
# N x 768 x 17 x 17
x = self.Mixed_6e(x)
# N x 768 x 17 x 17
aux = self.AuxLogits(x) if self.training else None
# N x 768 x 17 x 17
x = self.Mixed_7a(x)
# N x 1280 x 8 x 8
x = self.Mixed_7b(x)
# N x 2048 x 8 x 8
x = self.Mixed_7c(x)
# N x 2048 x 8 x 8
return x, aux
def get_classifier(self):
return self.fc
def reset_classifier(self, num_classes, global_pool='avg'):
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.num_classes = num_classes
if self.num_classes > 0:
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
else:
self.fc = nn.Identity()
def forward(self, x):
x, aux = self.forward_features(x)
x = self.global_pool(x).flatten(1)
if self.drop_rate > 0:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.fc(x)
return x, aux
class InceptionV3(nn.Module):
"""Inception-V3 with no AuxLogits
FIXME two class defs are redundant, but less screwing around with torchsript fussyness and inconsistent returns
"""
def __init__(self, inception_blocks=None, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg'):
super(InceptionV3, self).__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
if inception_blocks is None:
inception_blocks = [
BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE]
assert len(inception_blocks) >= 6
conv_block = inception_blocks[0]
inception_a = inception_blocks[1]
inception_b = inception_blocks[2]
inception_c = inception_blocks[3]
inception_d = inception_blocks[4]
inception_e = inception_blocks[5]
self.Conv2d_1a_3x3 = conv_block(in_chans, 32, kernel_size=3, stride=2)
self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
self.Mixed_5b = inception_a(192, pool_features=32)
self.Mixed_5c = inception_a(256, pool_features=64)
self.Mixed_5d = inception_a(288, pool_features=64)
self.Mixed_6a = inception_b(288)
self.Mixed_6b = inception_c(768, channels_7x7=128)
self.Mixed_6c = inception_c(768, channels_7x7=160)
self.Mixed_6d = inception_c(768, channels_7x7=160)
self.Mixed_6e = inception_c(768, channels_7x7=192)
self.Mixed_7a = inception_d(768)
self.Mixed_7b = inception_e(1280)
self.Mixed_7c = inception_e(2048)
self.num_features = 2048
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.fc = nn.Linear(2048, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
stddev = m.stddev if hasattr(m, 'stddev') else 0.1
trunc_normal_(m.weight, std=stddev)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward_features(self, x):
# N x 3 x 299 x 299
x = self.Conv2d_1a_3x3(x)
# N x 32 x 149 x 149
x = self.Conv2d_2a_3x3(x)
# N x 32 x 147 x 147
x = self.Conv2d_2b_3x3(x)
# N x 64 x 147 x 147
x = F.max_pool2d(x, kernel_size=3, stride=2)
# N x 64 x 73 x 73
x = self.Conv2d_3b_1x1(x)
# N x 80 x 73 x 73
x = self.Conv2d_4a_3x3(x)
# N x 192 x 71 x 71
x = F.max_pool2d(x, kernel_size=3, stride=2)
# N x 192 x 35 x 35
x = self.Mixed_5b(x)
# N x 256 x 35 x 35
x = self.Mixed_5c(x)
# N x 288 x 35 x 35
x = self.Mixed_5d(x)
# N x 288 x 35 x 35
x = self.Mixed_6a(x)
# N x 768 x 17 x 17
x = self.Mixed_6b(x)
# N x 768 x 17 x 17
x = self.Mixed_6c(x)
# N x 768 x 17 x 17
x = self.Mixed_6d(x)
# N x 768 x 17 x 17
x = self.Mixed_6e(x)
# N x 768 x 17 x 17
x = self.Mixed_7a(x)
# N x 1280 x 8 x 8
x = self.Mixed_7b(x)
# N x 2048 x 8 x 8
x = self.Mixed_7c(x)
# N x 2048 x 8 x 8
return x
def get_classifier(self):
return self.fc
def reset_classifier(self, num_classes, global_pool='avg'):
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.num_classes = num_classes
if self.num_classes > 0:
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
else:
self.fc = nn.Identity()
def forward(self, x):
x = self.forward_features(x)
x = self.global_pool(x).flatten(1)
if self.drop_rate > 0:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.fc(x)
return x
class InceptionA(nn.Module):
def __init__(self, in_channels, pool_features, conv_block=None):
super(InceptionA, self).__init__()
if conv_block is None:
conv_block = BasicConv2d
self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)
self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2)
self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)
self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)
def _forward(self, x):
branch1x1 = self.branch1x1(x)
branch5x5 = self.branch5x5_1(x)
branch5x5 = self.branch5x5_2(branch5x5)
branch3x3dbl = self.branch3x3dbl_1(x)
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
branch_pool = self.branch_pool(branch_pool)
outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
return outputs
def forward(self, x):
outputs = self._forward(x)
return torch.cat(outputs, 1)
class InceptionB(nn.Module):
def __init__(self, in_channels, conv_block=None):
super(InceptionB, self).__init__()
if conv_block is None:
conv_block = BasicConv2d
self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)
self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)
def _forward(self, x):
branch3x3 = self.branch3x3(x)
branch3x3dbl = self.branch3x3dbl_1(x)
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
outputs = [branch3x3, branch3x3dbl, branch_pool]
return outputs
def forward(self, x):
outputs = self._forward(x)
return torch.cat(outputs, 1)
class InceptionC(nn.Module):
def __init__(self, in_channels, channels_7x7, conv_block=None):
super(InceptionC, self).__init__()
if conv_block is None:
conv_block = BasicConv2d
self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)
c7 = channels_7x7
self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)
self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))
self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)
self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))
self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
def _forward(self, x):
branch1x1 = self.branch1x1(x)
branch7x7 = self.branch7x7_1(x)
branch7x7 = self.branch7x7_2(branch7x7)
branch7x7 = self.branch7x7_3(branch7x7)
branch7x7dbl = self.branch7x7dbl_1(x)
branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
branch_pool = self.branch_pool(branch_pool)
outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
return outputs
def forward(self, x):
outputs = self._forward(x)
return torch.cat(outputs, 1)
class InceptionD(nn.Module):
def __init__(self, in_channels, conv_block=None):
super(InceptionD, self).__init__()
if conv_block is None:
conv_block = BasicConv2d
self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)
self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)
def _forward(self, x):
branch3x3 = self.branch3x3_1(x)
branch3x3 = self.branch3x3_2(branch3x3)
branch7x7x3 = self.branch7x7x3_1(x)
branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
branch7x7x3 = self.branch7x7x3_4(branch7x7x3)
branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
outputs = [branch3x3, branch7x7x3, branch_pool]
return outputs
def forward(self, x):
outputs = self._forward(x)
return torch.cat(outputs, 1)
class InceptionE(nn.Module):
def __init__(self, in_channels, conv_block=None):
super(InceptionE, self).__init__()
if conv_block is None:
conv_block = BasicConv2d
self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)
self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)
self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)
self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)
self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
def _forward(self, x):
branch1x1 = self.branch1x1(x)
branch3x3 = self.branch3x3_1(x)
branch3x3 = [
self.branch3x3_2a(branch3x3),
self.branch3x3_2b(branch3x3),
]
branch3x3 = torch.cat(branch3x3, 1)
branch3x3dbl = self.branch3x3dbl_1(x)
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
branch3x3dbl = [
self.branch3x3dbl_3a(branch3x3dbl),
self.branch3x3dbl_3b(branch3x3dbl),
]
branch3x3dbl = torch.cat(branch3x3dbl, 1)
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
branch_pool = self.branch_pool(branch_pool)
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
return outputs
def forward(self, x):
outputs = self._forward(x)
return torch.cat(outputs, 1)
class InceptionAux(nn.Module):
def __init__(self, in_channels, num_classes, conv_block=None):
super(InceptionAux, self).__init__()
if conv_block is None:
conv_block = BasicConv2d
self.conv0 = conv_block(in_channels, 128, kernel_size=1)
self.conv1 = conv_block(128, 768, kernel_size=5)
self.conv1.stddev = 0.01
self.fc = nn.Linear(768, num_classes)
self.fc.stddev = 0.001
def forward(self, x):
# N x 768 x 17 x 17
x = F.avg_pool2d(x, kernel_size=5, stride=3)
# N x 768 x 5 x 5
x = self.conv0(x)
# N x 128 x 5 x 5
x = self.conv1(x)
# N x 768 x 1 x 1
# Adaptive average pooling
x = F.adaptive_avg_pool2d(x, (1, 1))
# N x 768 x 1 x 1
x = torch.flatten(x, 1)
# N x 768
x = self.fc(x)
# N x 1000
return x
class BasicConv2d(nn.Module):
def __init__(self, in_channels, out_channels, **kwargs):
super(BasicConv2d, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return F.relu(x, inplace=True)
def _inception_v3(variant, pretrained=False, **kwargs):
default_cfg = default_cfgs[variant]
if kwargs.pop('features_only', False):
assert False, 'Not Implemented' # TODO
load_strict = False
model_kwargs.pop('num_classes', 0)
model_class = InceptionV3
else:
aux_logits = kwargs.pop('aux_logits', False)
if aux_logits:
model_class = InceptionV3Aux
load_strict = default_cfg['has_aux']
else:
model_class = InceptionV3
load_strict = not default_cfg['has_aux']
model = model_class(**kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(
model,
num_classes=kwargs.get('num_classes', 0),
in_chans=kwargs.get('in_chans', 3),
strict=load_strict)
return model
@register_model @register_model
def inception_v3(pretrained=False, num_classes=1000, in_chans=3, **kwargs): def inception_v3(pretrained=False, **kwargs):
# original PyTorch weights, ported from Tensorflow but modified # original PyTorch weights, ported from Tensorflow but modified
default_cfg = default_cfgs['inception_v3'] model = _inception_v3('inception_v3', pretrained=pretrained, **kwargs)
assert in_chans == 3
_assert_default_kwargs(kwargs)
model = Inception3(num_classes=num_classes, aux_logits=True, transform_input=False)
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
model.default_cfg = default_cfg
return model return model
@register_model @register_model
def tf_inception_v3(pretrained=False, num_classes=1000, in_chans=3, **kwargs): def tf_inception_v3(pretrained=False, **kwargs):
# my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz) # my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)
default_cfg = default_cfgs['tf_inception_v3'] model = _inception_v3('tf_inception_v3', pretrained=pretrained, **kwargs)
assert in_chans == 3
_assert_default_kwargs(kwargs)
model = Inception3(num_classes=num_classes, aux_logits=False, transform_input=False)
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
model.default_cfg = default_cfg
return model return model
@register_model @register_model
def adv_inception_v3(pretrained=False, num_classes=1000, in_chans=3, **kwargs): def adv_inception_v3(pretrained=False, **kwargs):
# my port of Tensorflow adversarially trained Inception V3 from # my port of Tensorflow adversarially trained Inception V3 from
# http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz # http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz
default_cfg = default_cfgs['adv_inception_v3'] model = _inception_v3('adv_inception_v3', pretrained=pretrained, **kwargs)
assert in_chans == 3
_assert_default_kwargs(kwargs)
model = Inception3(num_classes=num_classes, aux_logits=False, transform_input=False)
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
model.default_cfg = default_cfg
return model return model
@register_model @register_model
def gluon_inception_v3(pretrained=False, num_classes=1000, in_chans=3, **kwargs): def gluon_inception_v3(pretrained=False, **kwargs):
# from gluon pretrained models, best performing in terms of accuracy/loss metrics # from gluon pretrained models, best performing in terms of accuracy/loss metrics
# https://gluon-cv.mxnet.io/model_zoo/classification.html # https://gluon-cv.mxnet.io/model_zoo/classification.html
default_cfg = default_cfgs['gluon_inception_v3'] model = _inception_v3('gluon_inception_v3', pretrained=pretrained, **kwargs)
assert in_chans == 3
_assert_default_kwargs(kwargs)
model = Inception3(num_classes=num_classes, aux_logits=False, transform_input=False)
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
model.default_cfg = default_cfg
return model return model

@ -22,3 +22,4 @@ from .blur_pool import BlurPool2d
from .norm_act import BatchNormAct2d from .norm_act import BatchNormAct2d
from .evo_norm import EvoNormBatch2d, EvoNormSample2d from .evo_norm import EvoNormBatch2d, EvoNormSample2d
from .create_norm_act import create_norm_act from .create_norm_act import create_norm_act
from .weight_init import trunc_normal_

@ -22,44 +22,89 @@ import math
def drop_block_2d( def drop_block_2d(
x, drop_prob: float = 0.1, training: bool = False, block_size: int = 7, x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0,
gamma_scale: float = 1.0, drop_with_noise: bool = False): with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
""" DropBlock. See https://arxiv.org/pdf/1810.12890.pdf """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
DropBlock with an experimental gaussian noise option. This layer has been tested on a few training DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
runs with success, but needs further validation and possibly optimization for lower runtime impact. runs with success, but needs further validation and possibly optimization for lower runtime impact.
""" """
if drop_prob == 0. or not training: B, C, H, W = x.shape
return x total_size = W * H
_, _, height, width = x.shape clipped_block_size = min(block_size, min(W, H))
total_size = width * height
clipped_block_size = min(block_size, min(width, height))
# seed_drop_rate, the gamma parameter # seed_drop_rate, the gamma parameter
seed_drop_rate = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
(width - block_size + 1) * (W - block_size + 1) * (H - block_size + 1))
(height - block_size + 1))
# Forces the block to be inside the feature map. # Forces the block to be inside the feature map.
w_i, h_i = torch.meshgrid(torch.arange(width).to(x.device), torch.arange(height).to(x.device)) w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device))
valid_block = ((w_i >= clipped_block_size // 2) & (w_i < width - (clipped_block_size - 1) // 2)) & \ valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
((h_i >= clipped_block_size // 2) & (h_i < height - (clipped_block_size - 1) // 2)) ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
valid_block = torch.reshape(valid_block, (1, 1, height, width)).float() valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)
uniform_noise = torch.rand_like(x, dtype=torch.float32) if batchwise:
block_mask = ((2 - seed_drop_rate - valid_block + uniform_noise) >= 1).float() # one mask for whole batch, quite a bit faster
uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
else:
uniform_noise = torch.rand_like(x)
block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
block_mask = -F.max_pool2d( block_mask = -F.max_pool2d(
-block_mask, -block_mask,
kernel_size=clipped_block_size, # block_size, ??? kernel_size=clipped_block_size, # block_size,
stride=1, stride=1,
padding=clipped_block_size // 2) padding=clipped_block_size // 2)
if drop_with_noise: if with_noise:
normal_noise = torch.randn_like(x) normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
x = x * block_mask + normal_noise * (1 - block_mask) if inplace:
x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
else:
x = x * block_mask + normal_noise * (1 - block_mask)
else:
normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
if inplace:
x.mul_(block_mask * normalize_scale)
else:
x = x * block_mask * normalize_scale
return x
def drop_block_fast_2d(
x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7,
gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
""" DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid
block mask at edges.
"""
B, C, H, W = x.shape
total_size = W * H
clipped_block_size = min(block_size, min(W, H))
gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
(W - block_size + 1) * (H - block_size + 1))
if batchwise:
# one mask for whole batch, quite a bit faster
block_mask = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) < gamma
else:
# mask per batch element
block_mask = torch.rand_like(x) < gamma
block_mask = F.max_pool2d(
block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2)
if with_noise:
normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
if inplace:
x.mul_(1. - block_mask).add_(normal_noise * block_mask)
else:
x = x * (1. - block_mask) + normal_noise * block_mask
else: else:
normalize_scale = block_mask.numel() / (torch.sum(block_mask) + 1e-7) block_mask = 1 - block_mask
x = x * block_mask * normalize_scale normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(dtype=x.dtype)
if inplace:
x.mul_(block_mask * normalize_scale)
else:
x = x * block_mask * normalize_scale
return x return x
@ -70,15 +115,28 @@ class DropBlock2d(nn.Module):
drop_prob=0.1, drop_prob=0.1,
block_size=7, block_size=7,
gamma_scale=1.0, gamma_scale=1.0,
with_noise=False): with_noise=False,
inplace=False,
batchwise=False,
fast=True):
super(DropBlock2d, self).__init__() super(DropBlock2d, self).__init__()
self.drop_prob = drop_prob self.drop_prob = drop_prob
self.gamma_scale = gamma_scale self.gamma_scale = gamma_scale
self.block_size = block_size self.block_size = block_size
self.with_noise = with_noise self.with_noise = with_noise
self.inplace = inplace
self.batchwise = batchwise
self.fast = fast # FIXME finish comparisons of fast vs not
def forward(self, x): def forward(self, x):
return drop_block_2d(x, self.drop_prob, self.training, self.block_size, self.gamma_scale, self.with_noise) if not self.training or not self.drop_prob:
return x
if self.fast:
return drop_block_fast_2d(
x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
else:
return drop_block_2d(
x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
def drop_path(x, drop_prob: float = 0., training: bool = False): def drop_path(x, drop_prob: float = 0., training: bool = False):

@ -0,0 +1,80 @@
""" Split Attention Conv2d (for ResNeSt Models)
Paper: `ResNeSt: Split-Attention Networks` - /https://arxiv.org/abs/2004.08955
Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt
Modified for torchscript compat, performance, and consistency with timm by Ross Wightman
"""
import torch
import torch.nn.functional as F
from torch import nn
class RadixSoftmax(nn.Module):
def __init__(self, radix, cardinality):
super(RadixSoftmax, self).__init__()
self.radix = radix
self.cardinality = cardinality
def forward(self, x):
batch = x.size(0)
if self.radix > 1:
x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
x = F.softmax(x, dim=1)
x = x.reshape(batch, -1)
else:
x = torch.sigmoid(x)
return x
class SplitAttnConv2d(nn.Module):
"""Split-Attention Conv2d
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
dilation=1, groups=1, bias=False, radix=2, reduction_factor=4,
act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs):
super(SplitAttnConv2d, self).__init__()
self.radix = radix
self.drop_block = drop_block
mid_chs = out_channels * radix
attn_chs = max(in_channels * radix // reduction_factor, 32)
self.conv = nn.Conv2d(
in_channels, mid_chs, kernel_size, stride, padding, dilation,
groups=groups * radix, bias=bias, **kwargs)
self.bn0 = norm_layer(mid_chs) if norm_layer is not None else None
self.act0 = act_layer(inplace=True)
self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups)
self.bn1 = norm_layer(attn_chs) if norm_layer is not None else None
self.act1 = act_layer(inplace=True)
self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups)
self.rsoftmax = RadixSoftmax(radix, groups)
def forward(self, x):
x = self.conv(x)
if self.bn0 is not None:
x = self.bn0(x)
if self.drop_block is not None:
x = self.drop_block(x)
x = self.act0(x)
B, RC, H, W = x.shape
if self.radix > 1:
x = x.reshape((B, self.radix, RC // self.radix, H, W))
x_gap = x.sum(dim=1)
else:
x_gap = x
x_gap = F.adaptive_avg_pool2d(x_gap, 1)
x_gap = self.fc1(x_gap)
if self.bn1 is not None:
x_gap = self.bn1(x_gap)
x_gap = self.act1(x_gap)
x_attn = self.fc2(x_gap)
x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1)
if self.radix > 1:
out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1)
else:
out = x * x_attn
return out.contiguous()

@ -0,0 +1,60 @@
import torch
import math
import warnings
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
# type: (Tensor, float, float, float, float) -> Tensor
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
return _no_grad_trunc_normal_(tensor, mean, std, a, b)

@ -42,12 +42,14 @@ def _natural_key(string_):
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
def list_models(filter='', module='', pretrained=False): def list_models(filter='', module='', pretrained=False, exclude_filters=''):
""" Return list of available model names, sorted alphabetically """ Return list of available model names, sorted alphabetically
Args: Args:
filter (str) - Wildcard filter string that works with fnmatch filter (str) - Wildcard filter string that works with fnmatch
module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet') module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet')
pretrained (bool) - Include only models with pretrained weights if True
exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
Example: Example:
model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet' model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
@ -58,7 +60,14 @@ def list_models(filter='', module='', pretrained=False):
else: else:
models = _model_entrypoints.keys() models = _model_entrypoints.keys()
if filter: if filter:
models = fnmatch.filter(models, filter) models = fnmatch.filter(models, filter) # include these models
if exclude_filters:
if not isinstance(exclude_filters, list):
exclude_filters = [exclude_filters]
for xf in exclude_filters:
exclude_models = fnmatch.filter(models, xf) # exclude these models
if len(exclude_models):
models = set(models).difference(exclude_models)
if pretrained: if pretrained:
models = _model_has_pretrained.intersection(models) models = _model_has_pretrained.intersection(models)
return list(sorted(models, key=_natural_key)) return list(sorted(models, key=_natural_key))

@ -0,0 +1,264 @@
""" ResNeSt Models
Paper: `ResNeSt: Split-Attention Networks` - https://arxiv.org/abs/2004.08955
Adapted from original PyTorch impl w/ weights at https://github.com/zhanghang1989/ResNeSt by Hang Zhang
Modified for torchscript compat, and consistency with timm by Ross Wightman
"""
import math
import torch
import torch.nn.functional as F
from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import DropBlock2d
from .helpers import load_pretrained
from .layers import SelectiveKernelConv, ConvBnAct, create_attn
from .layers.split_attn import SplitAttnConv2d
from .registry import register_model
from .resnet import ResNet
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bilinear',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'conv1', 'classifier': 'fc',
**kwargs
}
default_cfgs = {
'resnest14d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest14-9c8fe254.pth'),
'resnest26d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest26-50eb607c.pth'),
'resnest50d': _cfg(
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest50-528c19ca.pth'),
'resnest101e': _cfg(
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest101-22405ba7.pth', input_size=(3, 256, 256)),
'resnest200e': _cfg(
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest200-75117900.pth', input_size=(3, 320, 320)),
'resnest269e': _cfg(
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest269-0cc87c48.pth', input_size=(3, 416, 416)),
'resnest50d_4s2x40d': _cfg(
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest50_fast_4s2x40d-41d14ed0.pth',
interpolation='bicubic'),
'resnest50d_1s4x24d': _cfg(
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest50_fast_1s4x24d-d4a4f76f.pth',
interpolation='bicubic')
}
class ResNestBottleneck(nn.Module):
"""ResNet Bottleneck
"""
# pylint: disable=unused-argument
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None,
radix=1, cardinality=1, base_width=64, avd=False, avd_first=False, is_first=False,
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
super(ResNestBottleneck, self).__init__()
assert reduce_first == 1 # not supported
assert attn_layer is None # not supported
assert aa_layer is None # TODO not yet supported
assert drop_path is None # TODO not yet supported
group_width = int(planes * (base_width / 64.)) * cardinality
first_dilation = first_dilation or dilation
if avd and (stride > 1 or is_first):
avd_stride = stride
stride = 1
else:
avd_stride = 0
self.radix = radix
self.drop_block = drop_block
self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
self.bn1 = norm_layer(group_width)
self.act1 = act_layer(inplace=True)
self.avd_first = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and avd_first else None
if self.radix >= 1:
self.conv2 = SplitAttnConv2d(
group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation,
dilation=first_dilation, groups=cardinality, radix=radix, norm_layer=norm_layer, drop_block=drop_block)
self.bn2 = None # FIXME revisit, here to satisfy current torchscript fussyness
self.act2 = None
else:
self.conv2 = nn.Conv2d(
group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation,
dilation=first_dilation, groups=cardinality, bias=False)
self.bn2 = norm_layer(group_width)
self.act2 = act_layer(inplace=True)
self.avd_last = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and not avd_first else None
self.conv3 = nn.Conv2d(group_width, planes * 4, kernel_size=1, bias=False)
self.bn3 = norm_layer(planes*4)
self.act3 = act_layer(inplace=True)
self.downsample = downsample
def zero_init_last_bn(self):
nn.init.zeros_(self.bn3.weight)
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
if self.drop_block is not None:
out = self.drop_block(out)
out = self.act1(out)
if self.avd_first is not None:
out = self.avd_first(out)
out = self.conv2(out)
if self.bn2 is not None:
out = self.bn2(out)
if self.drop_block is not None:
out = self.drop_block(out)
out = self.act2(out)
if self.avd_last is not None:
out = self.avd_last(out)
out = self.conv3(out)
out = self.bn3(out)
if self.drop_block is not None:
out = self.drop_block(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.act3(out)
return out
@register_model
def resnest14d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
""" ResNeSt-14d model. Weights ported from GluonCV.
"""
default_cfg = default_cfgs['resnest14d']
model = ResNet(
ResNestBottleneck, [1, 1, 1, 1], num_classes=num_classes, in_chans=in_chans,
stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def resnest26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
""" ResNeSt-26d model. Weights ported from GluonCV.
"""
default_cfg = default_cfgs['resnest26d']
model = ResNet(
ResNestBottleneck, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans,
stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def resnest50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
""" ResNeSt-50d model. Matches paper ResNeSt-50 model, https://arxiv.org/abs/2004.08955
Since this codebase supports all possible variations, 'd' for deep stem, stem_width 32, avg in downsample.
"""
default_cfg = default_cfgs['resnest50d']
model = ResNet(
ResNestBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def resnest101e(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
""" ResNeSt-101e model. Matches paper ResNeSt-101 model, https://arxiv.org/abs/2004.08955
Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample.
"""
default_cfg = default_cfgs['resnest101e']
model = ResNet(
ResNestBottleneck, [3, 4, 23, 3], num_classes=num_classes, in_chans=in_chans,
stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def resnest200e(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
""" ResNeSt-200e model. Matches paper ResNeSt-200 model, https://arxiv.org/abs/2004.08955
Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample.
"""
default_cfg = default_cfgs['resnest200e']
model = ResNet(
ResNestBottleneck, [3, 24, 36, 3], num_classes=num_classes, in_chans=in_chans,
stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def resnest269e(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
""" ResNeSt-269e model. Matches paper ResNeSt-269 model, https://arxiv.org/abs/2004.08955
Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample.
"""
default_cfg = default_cfgs['resnest269e']
model = ResNet(
ResNestBottleneck, [3, 30, 48, 8], num_classes=num_classes, in_chans=in_chans,
stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def resnest50d_4s2x40d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""ResNeSt-50 4s2x40d from https://github.com/zhanghang1989/ResNeSt/blob/master/ablation.md
"""
default_cfg = default_cfgs['resnest50d_4s2x40d']
model = ResNet(
ResNestBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
stem_type='deep', stem_width=32, avg_down=True, base_width=40, cardinality=2,
block_args=dict(radix=4, avd=True, avd_first=True), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def resnest50d_1s4x24d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""ResNeSt-50 1s4x24d from https://github.com/zhanghang1989/ResNeSt/blob/master/ablation.md
"""
default_cfg = default_cfgs['resnest50d_1s4x24d']
model = ResNet(
ResNestBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
stem_type='deep', stem_width=32, avg_down=True, base_width=24, cardinality=4,
block_args=dict(radix=1, avd=True, avd_first=True), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
Loading…
Cancel
Save