You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
170 lines
6.7 KiB
170 lines
6.7 KiB
5 years ago
|
""" DropBlock, DropPath
|
||
|
|
||
|
PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.

Papers:
DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890)

Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)

Code:
DropBlock impl inspired by two TensorFlow implementations that I liked:
  - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74
  - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py

Hacked together by / Copyright 2020 Ross Wightman
|
||
5 years ago
|
"""
|
||
5 years ago
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.nn.functional as F
|
||
|
|
||
|
|
||
5 years ago
|
def drop_block_2d(
        x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0,
        with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf

    DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
    runs with success, but needs further validation and possibly optimization for lower runtime impact.

    Args:
        x: Input tensor, unpacked as (B, C, H, W).
        drop_prob: Target drop probability used to derive the per-position seed rate (gamma).
        block_size: Side length of the square block dropped around each seed. Clipped to min(H, W).
        gamma_scale: Multiplier applied to the computed gamma.
        with_noise: If True, dropped positions are filled with gaussian noise and no rescaling is done.
            If False, dropped positions are zeroed and the result is rescaled by numel / kept-count.
        inplace: If True, ``x`` is modified in place and returned.
        batchwise: If True, a single (1, C, H, W) mask (and noise tensor) is shared by the whole batch.

    Returns:
        Tensor with the same shape as ``x``.
    """
    B, C, H, W = x.shape
    total_size = W * H
    clipped_block_size = min(block_size, min(W, H))
    # seed_drop_rate, the gamma parameter. The number of valid seed positions is computed from the
    # *clipped* block size so the denominator is always >= 1, even when block_size exceeds H or W.
    gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
        (W - clipped_block_size + 1) * (H - clipped_block_size + 1))

    # Forces the block to be inside the feature map. Row / column indices are broadcast directly
    # into an (H, W) grid so that non-square feature maps get a correctly laid out mask.
    lead = clipped_block_size // 2
    trail = (clipped_block_size - 1) // 2
    h_i = torch.arange(H, device=x.device).view(H, 1)
    w_i = torch.arange(W, device=x.device).view(1, W)
    valid_block = ((w_i >= lead) & (w_i < W - trail)) & ((h_i >= lead) & (h_i < H - trail))
    valid_block = valid_block.reshape(1, 1, H, W).to(dtype=x.dtype)

    if batchwise:
        # one mask for whole batch, quite a bit faster
        uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
    else:
        uniform_noise = torch.rand_like(x)
    # 1 == keep, 0 == seed of a dropped block. Positions outside the valid region are always kept.
    block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
    # Min-pool (negated max-pool) grows every zero seed into a block_size x block_size block of zeros.
    block_mask = -F.max_pool2d(
        -block_mask,
        kernel_size=clipped_block_size,
        stride=1,
        padding=clipped_block_size // 2)
    if clipped_block_size % 2 == 0:
        # An even kernel with padding k // 2 yields an (H + 1, W + 1) output; dropping the first
        # row / column restores (H, W) and keeps blocks grown from valid seeds inside the map.
        block_mask = block_mask[..., 1:, 1:]

    if with_noise:
        normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
        if inplace:
            x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
        else:
            x = x * block_mask + normal_noise * (1 - block_mask)
    else:
        # Rescale so the overall activation magnitude is preserved; epsilon guards an all-dropped mask.
        normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
        if inplace:
            x.mul_(block_mask * normalize_scale)
        else:
            x = x * block_mask * normalize_scale
    return x
|
||
|
|
||
|
|
||
|
def drop_block_fast_2d(
        x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7,
        gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False):
    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf

    DropBlock with an experimental gaussian noise option. Simplified from drop_block_2d without concern
    for valid block mask at edges.

    Args:
        x: Input tensor, unpacked as (B, C, H, W).
        drop_prob: Target drop probability used to derive the per-position seed rate (gamma).
        block_size: Side length of the square block dropped around each seed. Clipped to min(H, W).
        gamma_scale: Multiplier applied to the computed gamma.
        with_noise: If True, dropped positions are filled with gaussian noise and no rescaling is done.
            If False, dropped positions are zeroed and the result is rescaled by numel / kept-count.
        inplace: If True, ``x`` is modified in place and returned.

    Returns:
        Tensor with the same shape as ``x``.
    """
    B, C, H, W = x.shape
    total_size = W * H
    clipped_block_size = min(block_size, min(W, H))
    # The seed-position count uses the *clipped* block size so the denominator is always >= 1 and
    # gamma stays a valid (non-negative) Bernoulli probability when block_size exceeds H or W.
    gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
        (W - clipped_block_size + 1) * (H - clipped_block_size + 1))

    # Here 1 marks a drop seed; max-pooling grows each seed into a block_size x block_size block.
    block_mask = torch.empty_like(x).bernoulli_(gamma)
    block_mask = F.max_pool2d(
        block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2)
    if clipped_block_size % 2 == 0:
        # An even kernel with padding k // 2 yields an (H + 1, W + 1) output; crop back to (H, W).
        block_mask = block_mask[..., 1:, 1:]

    if with_noise:
        normal_noise = torch.empty_like(x).normal_()
        if inplace:
            x.mul_(1. - block_mask).add_(normal_noise * block_mask)
        else:
            x = x * (1. - block_mask) + normal_noise * block_mask
    else:
        # Flip to a keep-mask, then rescale; epsilon guards an all-dropped mask.
        block_mask = 1 - block_mask
        normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-6)).to(dtype=x.dtype)
        if inplace:
            x.mul_(block_mask * normalize_scale)
        else:
            x = x * block_mask * normalize_scale
    return x
|
||
|
|
||
|
|
||
|
class DropBlock2d(nn.Module):
    """ DropBlock regularization as a module. See https://arxiv.org/pdf/1810.12890.pdf

    Stores the DropBlock settings and, in training mode with a non-zero ``drop_prob``, forwards to
    ``drop_block_fast_2d`` (when ``fast`` is set) or ``drop_block_2d``. Otherwise the input is
    returned untouched. ``batchwise`` is only passed along on the non-fast path.
    """

    def __init__(
            self,
            drop_prob: float = 0.1,
            block_size: int = 7,
            gamma_scale: float = 1.0,
            with_noise: bool = False,
            inplace: bool = False,
            batchwise: bool = False,
            fast: bool = True):
        super().__init__()
        # Mask sampling parameters
        self.drop_prob = drop_prob
        self.block_size = block_size
        self.gamma_scale = gamma_scale
        # Behaviour switches
        self.with_noise = with_noise
        self.inplace = inplace
        self.batchwise = batchwise
        self.fast = fast  # FIXME finish comparisons of fast vs not

    def forward(self, x):
        # Pass-through outside of training or when dropping is disabled.
        if not self.training:
            return x
        if not self.drop_prob:
            return x

        if not self.fast:
            return drop_block_2d(
                x,
                drop_prob=self.drop_prob,
                block_size=self.block_size,
                gamma_scale=self.gamma_scale,
                with_noise=self.with_noise,
                inplace=self.inplace,
                batchwise=self.batchwise,
            )
        return drop_block_fast_2d(
            x,
            drop_prob=self.drop_prob,
            block_size=self.block_size,
            gamma_scale=self.gamma_scale,
            with_noise=self.with_noise,
            inplace=self.inplace,
        )
|
||
5 years ago
|
|
||
|
|
||
3 years ago
|
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    Args:
        x: Input tensor; dim 0 is treated as the per-sample dimension.
        drop_prob: Probability of zeroing a whole sample.
        training: When False the input is returned unchanged.
        scale_by_keep: When True (and keep probability > 0), surviving samples are divided by the
            keep probability.

    Returns:
        ``x`` itself when inactive, otherwise ``x`` multiplied by the per-sample mask.
    """
    if not training or drop_prob == 0.:
        return x

    survival = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over every remaining dim
    # (work with diff dim tensors, not just 2D ConvNets).
    trailing_ones = (1,) * (x.ndim - 1)
    mask_shape = (x.shape[0],) + trailing_ones
    sample_mask = x.new_empty(mask_shape).bernoulli_(survival)
    if scale_by_keep and survival > 0.0:
        sample_mask.div_(survival)
    return x * sample_mask
|
||
5 years ago
|
|
||
|
|
||
4 years ago
|
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Module wrapper that holds ``drop_prob`` / ``scale_by_keep`` and calls ``drop_path`` with the
    module's current ``training`` flag.
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super().__init__()
        self.scale_by_keep = scale_by_keep
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(
            x,
            drop_prob=self.drop_prob,
            training=self.training,
            scale_by_keep=self.scale_by_keep,
        )

    def extra_repr(self):
        # Shown inside the module repr: drop probability with three decimals.
        return f'drop_prob={round(self.drop_prob,3):0.3f}'
|