diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 828c20b2..f012c3cf 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -12,6 +12,6 @@ from .eca import EcaModule, CecaModule from .activations import * from .adaptive_avgmax_pool import \ adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d -from .drop import DropBlock2d, DropPath +from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path from .test_time_pool import TestTimePoolHead, apply_test_time_pool from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model diff --git a/timm/models/layers/drop.py b/timm/models/layers/drop.py index 46d5d20b..669dbf24 100644 --- a/timm/models/layers/drop.py +++ b/timm/models/layers/drop.py @@ -2,6 +2,16 @@ PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers. +Papers: +DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890) + +Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382) + +Code: +DropBlock impl inspired by two Tensorflow impl that I liked: + - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74 + - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py + Hacked together by Ross Wightman """ import torch @@ -11,9 +21,15 @@ import numpy as np import math -def drop_block_2d(x, drop_prob=0.1, block_size=7, gamma_scale=1.0, drop_with_noise=False): +def drop_block_2d(x, drop_prob=0.1, training=False, block_size=7, gamma_scale=1.0, drop_with_noise=False): """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + + DropBlock with an experimental gaussian noise option. This layer has been tested on a few training + runs with success, but needs further validation and possibly optimization for lower runtime impact. + """ + if drop_prob == 0. or not training: + return x _, _, height, width = x.shape total_size = width * height clipped_block_size = min(block_size, min(width, height)) @@ -60,14 +76,21 @@ class DropBlock2d(nn.Module): self.with_noise = with_noise def forward(self, x): - if not self.training or not self.drop_prob: - return x - return drop_block_2d(x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise) + return drop_block_2d(x, self.drop_prob, self.training, self.block_size, self.gamma_scale, self.with_noise) + + +def drop_path(x, drop_prob=0., training=False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. -def drop_path(x, drop_prob=0.): - """Drop paths (Stochastic Depth) per sample (when applied in residual blocks). """ + if drop_prob == 0. or not training: + return x keep_prob = 1 - drop_prob random_tensor = keep_prob + torch.rand((x.size()[0], 1, 1, 1), dtype=x.dtype, device=x.device) random_tensor.floor_() # binarize @@ -76,13 +99,11 @@ def drop_path(x, drop_prob=0.): class DropPath(nn.ModuleDict): - """Drop paths (Stochastic Depth) per sample (when applied in residual blocks). + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """ def __init__(self, drop_prob=None): super(DropPath, self).__init__() self.drop_prob = drop_prob def forward(self, x): - if not self.training or not self.drop_prob: - return x - return drop_path(x, self.drop_prob) + return drop_path(x, self.drop_prob, self.training) diff --git a/timm/models/layers/selective_kernel.py b/timm/models/layers/selective_kernel.py index cb7e29ad..fcb26947 100644 --- a/timm/models/layers/selective_kernel.py +++ b/timm/models/layers/selective_kernel.py @@ -21,6 +21,11 @@ def _kernel_valid(k): class SelectiveKernelAttn(nn.Module): def __init__(self, channels, num_paths=2, attn_channels=32, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + """ Selective Kernel Attention Module + + Selective Kernel attention mechanism factored out into its own module. + + """ super(SelectiveKernelAttn, self).__init__() self.num_paths = num_paths self.pool = nn.AdaptiveAvgPool2d(1) @@ -48,8 +53,33 @@ class SelectiveKernelConv(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1, attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False, drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + """ Selective Kernel Convolution Module + + As described in Selective Kernel Networks (https://arxiv.org/abs/1903.06586) with some modifications. + + Largest change is the input split, which divides the input channels across each convolution path, this can + be viewed as a grouping of sorts, but the output channel counts expand to the module level value. This keeps + the parameter count from ballooning when the convolutions themselves don't have groups, but still provides + a noteworthy increase in performance over similar param count models without this attention layer. -Ross W + + Args: + in_channels (int): module input (feature) channel count + out_channels (int): module output (feature) channel count + kernel_size (int, list): kernel size for each convolution branch + stride (int): stride for convolutions + dilation (int): dilation for module as a whole, impacts dilation of each branch + groups (int): number of groups for each branch + attn_reduction (int, float): reduction factor for attention features + min_attn_channels (int): minimum attention feature channels + keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations + split_input (bool): split input channels evenly across each convolution branch, keeps param count lower, + can be viewed as grouping by path, output expands to module out_channels count + drop_block (nn.Module): drop block module + act_layer (nn.Module): activation layer to use + norm_layer (nn.Module): batchnorm/norm layer to use + """ super(SelectiveKernelConv, self).__init__() - kernel_size = kernel_size or [3, 5] + kernel_size = kernel_size or [3, 5] # default to one 3x3 and one 5x5 branch. 5x5 -> 3x3 + dilation _kernel_valid(kernel_size) if not isinstance(kernel_size, list): kernel_size = [kernel_size] * 2 diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 5b020272..456dc129 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -382,7 +382,7 @@ class ResNet(nn.Module): self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # Feature Blocks - dp = DropPath(drop_path_rate) if drop_block_rate else None + dp = DropPath(drop_path_rate) if drop_path_rate else None db_3 = DropBlock2d(drop_block_rate, 7, 0.25) if drop_block_rate else None db_4 = DropBlock2d(drop_block_rate, 7, 1.00) if drop_block_rate else None channels, strides, dilations = [64, 128, 256, 512], [1, 2, 2, 2], [1] * 4 diff --git a/timm/models/sknet.py b/timm/models/sknet.py index 7737bfed..d9657352 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -2,6 +2,10 @@ Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586) +This was inspired by reading 'Compounding the Performance Improvements...' (https://arxiv.org/abs/2001.06268) +and a streamlined impl at https://github.com/clovaai/assembled-cnn but I ended up building something closer +to the original paper with some modifications of my own to better balance param count vs accuracy. + Hacked together by Ross Wightman """ import math @@ -29,7 +33,8 @@ def _cfg(url='', **kwargs): default_cfgs = { 'skresnet18': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet18_ra-4eec2804.pth'), - 'skresnet34': _cfg(url=''), + 'skresnet34': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet34_ra-bdc0ccde.pth'), 'skresnet50': _cfg(), 'skresnet50d': _cfg(), 'skresnext50_32x4d': _cfg(