Update comments for Selective Kernel and DropBlock/Path impl, add skresnet34 weights

pull/88/head
Ross Wightman 5 years ago
parent 569419b38d
commit f1d5f8a6c4

@@ -12,6 +12,6 @@ from .eca import EcaModule, CecaModule
from .activations import *
from .adaptive_avgmax_pool import \
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
from .drop import DropBlock2d, DropPath
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
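
The __init__ change above exposes the functional forms alongside the module classes. A minimal usage sketch, assuming the package path timm/models/layers as laid out in the repo (tensor shapes are illustrative):

import torch
from timm.models.layers import drop_path, drop_block_2d

x = torch.randn(8, 64, 56, 56)                                    # NCHW feature map
# the functional forms take an explicit training flag instead of reading module state
y = drop_path(x, drop_prob=0.1, training=True)
z = drop_block_2d(x, drop_prob=0.1, training=True, block_size=7)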

@@ -2,6 +2,16 @@
PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.
Papers:
DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890)
Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)
Code:
DropBlock impl inspired by two Tensorflow impls that I liked:
- https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74
- https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py
Hacked together by Ross Wightman
"""
import torch
@@ -11,9 +21,15 @@ import numpy as np
import math
def drop_block_2d(x, drop_prob=0.1, block_size=7, gamma_scale=1.0, drop_with_noise=False):
def drop_block_2d(x, drop_prob=0.1, training=False, block_size=7, gamma_scale=1.0, drop_with_noise=False):
""" DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
runs with success, but needs further validation and possibly optimization for lower runtime impact.
"""
if drop_prob == 0. or not training:
return x
_, _, height, width = x.shape
total_size = width * height
clipped_block_size = min(block_size, min(width, height))
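
For context, the seed probability gamma is derived from drop_prob as in Eq. 1 of the DropBlock paper. A hedged sketch using the variables visible above (the repo's exact expression may differ slightly):

# scale drop_prob so that, once each sampled seed grows into a
# clipped_block_size x clipped_block_size square, the expected fraction of
# zeroed activations is roughly drop_prob; gamma_scale is an extra multiplier
gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
    (width - block_size + 1) * (height - block_size + 1))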
@@ -60,14 +76,21 @@ class DropBlock2d(nn.Module):
self.with_noise = with_noise
def forward(self, x):
if not self.training or not self.drop_prob:
return x
return drop_block_2d(x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise)
return drop_block_2d(x, self.drop_prob, self.training, self.block_size, self.gamma_scale, self.with_noise)
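
A minimal sketch of the module form; the positional constructor arguments (drop_prob, block_size, gamma_scale) are assumed from the ResNet hunk further down:

import torch
from timm.models.layers import DropBlock2d

db = DropBlock2d(0.1, 7, 0.25)            # drop_prob, block_size, gamma_scale (assumed order)
db.train()                                # forward now passes self.training to drop_block_2d
out = db(torch.randn(2, 256, 28, 28))
db.eval()                                 # in eval mode (or drop_prob == 0) the input passes through unchanged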
def drop_path(x, drop_prob=0., training=False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc. networks; however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mixing DropConnect as a layer name and using
'survival rate' as the argument.
def drop_path(x, drop_prob=0.):
"""Drop paths (Stochastic Depth) per sample (when applied in residual blocks).
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
random_tensor = keep_prob + torch.rand((x.size()[0], 1, 1, 1), dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
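
The binarization trick above: keep_prob + U[0, 1) floors to 1 with probability keep_prob and to 0 otherwise, giving a per-sample Bernoulli mask; surviving samples are then rescaled by 1 / keep_prob (the rescaling line falls outside this hunk). A small numeric sketch:

import torch

keep_prob = 0.75                              # i.e. drop_prob = 0.25
r = keep_prob + torch.rand(4, 1, 1, 1)        # values in [0.75, 1.75)
mask = r.floor()                              # 1 with prob 0.75, 0 with prob 0.25, per sample
# kept samples are divided by keep_prob so the expected activation magnitude is unchanged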
@@ -76,13 +99,11 @@ def drop_path(x, drop_prob=0.):
class DropPath(nn.ModuleDict):
"""Drop paths (Stochastic Depth) per sample (when applied in residual blocks).
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
if not self.training or not self.drop_prob:
return x
return drop_path(x, self.drop_prob)
return drop_path(x, self.drop_prob, self.training)
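
A sketch of how the module form is typically used inside a residual block; the block class here is illustrative, not from the repo:

import torch.nn as nn
from timm.models.layers import DropPath

class ToyResidualBlock(nn.Module):            # illustrative only
    def __init__(self, dim, drop_path_rate=0.1):
        super().__init__()
        self.conv = nn.Conv2d(dim, dim, 3, padding=1)
        self.drop_path = DropPath(drop_path_rate)

    def forward(self, x):
        # the residual branch is dropped per sample during training
        return x + self.drop_path(self.conv(x))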

@@ -21,6 +21,11 @@ def _kernel_valid(k):
class SelectiveKernelAttn(nn.Module):
def __init__(self, channels, num_paths=2, attn_channels=32,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
""" Selective Kernel Attention Module
Selective Kernel attention mechanism factored out into its own module.
"""
super(SelectiveKernelAttn, self).__init__()
self.num_paths = num_paths
self.pool = nn.AdaptiveAvgPool2d(1)
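
For orientation, a hedged sketch of the selective-kernel attention mechanism this module implements (layer names are illustrative; only part of the real module is visible in this hunk):

import torch
import torch.nn as nn

class SKAttnSketch(nn.Module):                     # illustrative sketch, not the repo class
    def __init__(self, channels, num_paths=2, attn_channels=32):
        super().__init__()
        self.num_paths = num_paths
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc_reduce = nn.Conv2d(channels, attn_channels, 1, bias=False)
        self.act = nn.ReLU(inplace=True)
        self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, 1, bias=False)

    def forward(self, x):                          # x: (B, num_paths, C, H, W), one map per branch
        s = self.pool(x.sum(dim=1))                # fuse branches, then global average pool
        z = self.act(self.fc_reduce(s))            # squeeze to attn_channels
        a = self.fc_select(z).view(x.shape[0], self.num_paths, -1, 1, 1)
        a = torch.softmax(a, dim=1)                # soft selection across paths, per channel
        return (x * a).sum(dim=1)                  # weighted sum of branch outputs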
@@ -48,8 +53,33 @@ class SelectiveKernelConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1,
attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False,
drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
""" Selective Kernel Convolution Module
As described in Selective Kernel Networks (https://arxiv.org/abs/1903.06586) with some modifications.
Largest change is the input split, which divides the input channels across each convolution path; this can
be viewed as a grouping of sorts, but the output channel counts expand to the module-level value. This keeps
the parameter count from ballooning when the convolutions themselves don't have groups, but still provides
a noteworthy increase in performance over similar param count models without this attention layer. -Ross W
Args:
in_channels (int): module input (feature) channel count
out_channels (int): module output (feature) channel count
kernel_size (int, list): kernel size for each convolution branch
stride (int): stride for convolutions
dilation (int): dilation for module as a whole, impacts dilation of each branch
groups (int): number of groups for each branch
attn_reduction (int, float): reduction factor for attention features
min_attn_channels (int): minimum attention feature channels
keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations
split_input (bool): split input channels evenly across each convolution branch, keeps param count lower,
can be viewed as grouping by path, output expands to module out_channels count
drop_block (nn.Module): drop block module
act_layer (nn.Module): activation layer to use
norm_layer (nn.Module): batchnorm/norm layer to use
"""
super(SelectiveKernelConv, self).__init__()
kernel_size = kernel_size or [3, 5]
kernel_size = kernel_size or [3, 5] # default to one 3x3 and one 5x5 branch. 5x5 -> 3x3 + dilation
_kernel_valid(kernel_size)
if not isinstance(kernel_size, list):
kernel_size = [kernel_size] * 2
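
A usage sketch based on the constructor signature documented above; the import path is assumed and the values are illustrative:

import torch
from timm.models.layers import SelectiveKernelConv   # import path assumed

# defaults give one 3x3 and one 5x5 branch (the 5x5 realized as a dilated 3x3 when keep_3x3=True);
# split_input=True divides the 128 input channels evenly across the two branches
sk = SelectiveKernelConv(128, 256, stride=2, split_input=True)
out = sk(torch.randn(2, 128, 56, 56))                 # expected shape (2, 256, 28, 28)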

@@ -382,7 +382,7 @@ class ResNet(nn.Module):
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# Feature Blocks
dp = DropPath(drop_path_rate) if drop_block_rate else None
dp = DropPath(drop_path_rate) if drop_path_rate else None
db_3 = DropBlock2d(drop_block_rate, 7, 0.25) if drop_block_rate else None
db_4 = DropBlock2d(drop_block_rate, 7, 1.00) if drop_block_rate else None
channels, strides, dilations = [64, 128, 256, 512], [1, 2, 2, 2], [1] * 4
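
A hedged sketch of wiring these rates through the constructor; the argument names are assumed from the variables in this hunk and the import path from the repo layout:

from timm.models.resnet import ResNet, Bottleneck     # import path assumed

# db_3 uses a lighter gamma_scale (0.25) than db_4 (1.00), so DropBlock acts more
# gently on the higher-resolution stage 3 than on stage 4
model = ResNet(Bottleneck, [3, 4, 6, 3], drop_path_rate=0.05, drop_block_rate=0.1)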

@@ -2,6 +2,10 @@
Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586)
This was inspired by reading 'Compounding the Performance Improvements...' (https://arxiv.org/abs/2001.06268)
and a streamlined impl at https://github.com/clovaai/assembled-cnn, but I ended up building something closer
to the original paper with some modifications of my own to better balance param count vs accuracy.
Hacked together by Ross Wightman
"""
import math
@@ -29,7 +33,8 @@ def _cfg(url='', **kwargs):
default_cfgs = {
'skresnet18': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet18_ra-4eec2804.pth'),
'skresnet34': _cfg(url=''),
'skresnet34': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet34_ra-bdc0ccde.pth'),
'skresnet50': _cfg(),
'skresnet50d': _cfg(),
'skresnext50_32x4d': _cfg(

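With the skresnet34 URL filled in above, pretrained weights can be pulled through the usual timm entry point; a minimal sketch:

import torch
import timm

model = timm.create_model('skresnet34', pretrained=True)   # downloads the skresnet34_ra weights above
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))             # 224x224 input assumed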