Improve dropblock impl, add fast variant, and better AMP speed, inplace, batchwise... few ResNeSt cleanups

pull/148/head
Ross Wightman 4 years ago
parent 63addb741f
commit 1904ed8fec

@@ -22,44 +22,89 @@ import math

def drop_block_2d(
        x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0,
        with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf

    DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
    runs with success, but needs further validation and possibly optimization for lower runtime impact.
    """
    B, C, H, W = x.shape
    total_size = W * H
    clipped_block_size = min(block_size, min(W, H))
    # seed_drop_rate, the gamma parameter
    gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
        (W - block_size + 1) * (H - block_size + 1))

    # Forces the block to be inside the feature map.
    w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device))
    valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
                  ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
    valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)

    if batchwise:
        # one mask for whole batch, quite a bit faster
        uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
    else:
        uniform_noise = torch.rand_like(x)
    block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
    block_mask = -F.max_pool2d(
        -block_mask,
        kernel_size=clipped_block_size,  # block_size,
        stride=1,
        padding=clipped_block_size // 2)

    if with_noise:
        normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
        if inplace:
            x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
        else:
            x = x * block_mask + normal_noise * (1 - block_mask)
    else:
        normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
        if inplace:
            x.mul_(block_mask * normalize_scale)
        else:
            x = x * block_mask * normalize_scale
    return x

def drop_block_fast_2d(
        x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7,
        gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf

    DropBlock with an experimental gaussian noise option. Simplified from above without concern for valid
    block mask at edges.
    """
    B, C, H, W = x.shape
    total_size = W * H
    clipped_block_size = min(block_size, min(W, H))
    gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
        (W - block_size + 1) * (H - block_size + 1))

    if batchwise:
        # one mask for whole batch, quite a bit faster
        block_mask = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) < gamma
    else:
        # mask per batch element
        block_mask = torch.rand_like(x) < gamma
    block_mask = F.max_pool2d(
        block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2)

    if with_noise:
        normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
        if inplace:
            x.mul_(1. - block_mask).add_(normal_noise * block_mask)
        else:
            x = x * (1. - block_mask) + normal_noise * block_mask
    else:
        block_mask = 1 - block_mask
        normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(dtype=x.dtype)
        if inplace:
            x.mul_(block_mask * normalize_scale)
        else:
            x = x * block_mask * normalize_scale
    return x
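
As a quick sanity check on the seed-rate formula above, here is an illustrative snippet (not part of the commit) that plugs typical values into the gamma expression and calls the fast variant directly; the tensor shape and hyperparameters are arbitrary, and drop_block_fast_2d is assumed to be in scope alongside torch.

import torch

# Hypothetical feature map and DropBlock settings, for illustration only.
feat = torch.randn(8, 64, 28, 28)                      # B, C, H, W
drop_prob, block_size, gamma_scale = 0.1, 7, 1.0
H = W = 28
gamma = gamma_scale * drop_prob * (H * W) / block_size ** 2 / ((W - block_size + 1) * (H - block_size + 1))
# gamma = 0.1 * 784 / 49 / 484 ≈ 0.0033: each spatial position seeds a 7x7 drop block with
# ~0.33% probability, sized so the total dropped area tracks drop_prob before renormalization.
out = drop_block_fast_2d(feat, drop_prob=drop_prob, block_size=block_size, gamma_scale=gamma_scale, batchwise=True)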
@@ -70,15 +115,28 @@ class DropBlock2d(nn.Module):
            drop_prob=0.1,
            block_size=7,
            gamma_scale=1.0,
            with_noise=False,
            inplace=False,
            batchwise=False,
            fast=True):
        super(DropBlock2d, self).__init__()
        self.drop_prob = drop_prob
        self.gamma_scale = gamma_scale
        self.block_size = block_size
        self.with_noise = with_noise
        self.inplace = inplace
        self.batchwise = batchwise
        self.fast = fast  # FIXME finish comparisons of fast vs not

    def forward(self, x):
        if not self.training or not self.drop_prob:
            return x
        if self.fast:
            return drop_block_fast_2d(
                x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
        else:
            return drop_block_2d(
                x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
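
For orientation, a minimal usage sketch of the updated module (illustrative only, not from the commit); the surrounding layer sizes and flag values are made up, and only fast=True matches the new default.

import torch
import torch.nn as nn

# Hypothetical conv stage with DropBlock2d mixed in; channel and spatial sizes are arbitrary.
stage = nn.Sequential(
    nn.Conv2d(64, 64, 3, padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
    DropBlock2d(drop_prob=0.1, block_size=7, gamma_scale=1.0, batchwise=True, fast=True),
)
stage.train()                                  # DropBlock is a no-op in eval mode or when drop_prob is 0
y = stage(torch.randn(4, 64, 56, 56))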

def drop_path(x, drop_prob: float = 0., training: bool = False):

@@ -31,25 +31,24 @@ class RadixSoftmax(nn.Module):
class SplitAttnConv2d(nn.Module):
    """Split-Attention Conv2d
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, bias=False, radix=2, reduction_factor=4,
                 act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs):
        super(SplitAttnConv2d, self).__init__()
        self.radix = radix
        self.drop_block = drop_block
        mid_chs = out_channels * radix
        attn_chs = max(in_channels * radix // reduction_factor, 32)

        self.conv = nn.Conv2d(
            in_channels, mid_chs, kernel_size, stride, padding, dilation,
            groups=groups * radix, bias=bias, **kwargs)
        self.bn0 = norm_layer(mid_chs) if norm_layer is not None else None
        self.act0 = act_layer(inplace=True)
        self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups)
        self.bn1 = norm_layer(attn_chs) if norm_layer is not None else None
        self.act1 = act_layer(inplace=True)
        self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups)
        self.rsoftmax = RadixSoftmax(radix, groups)

    def forward(self, x):
@@ -63,7 +62,7 @@ class SplitAttnConv2d(nn.Module):
        B, RC, H, W = x.shape
        if self.radix > 1:
            x = x.reshape((B, self.radix, RC // self.radix, H, W))
            x_gap = x.sum(dim=1)
        else:
            x_gap = x
        x_gap = F.adaptive_avg_pool2d(x_gap, 1)
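
To make the radix bookkeeping above concrete, a standalone shape walk-through (illustrative sizes only, assuming torch and torch.nn.functional as F are imported):

import torch
import torch.nn.functional as F

# With radix=2 and out_channels=64 on a 56x56 map (hypothetical sizes):
x = torch.randn(2, 2 * 64, 56, 56)            # conv output: (B, radix * out_channels, H, W)
x = x.reshape(2, 2, 64, 56, 56)               # split out the radix groups
x_gap = x.sum(dim=1)                          # (2, 64, 56, 56): sum over radix before pooling
x_gap = F.adaptive_avg_pool2d(x_gap, 1)       # (2, 64, 1, 1): global context fed to fc1/fc2 and RadixSoftmax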

@@ -76,10 +76,10 @@ class ResNestBottleneck(nn.Module):
        else:
            avd_stride = 0
        self.radix = radix
        self.drop_block = drop_block

        self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
        self.bn1 = norm_layer(group_width)
        self.act1 = act_layer(inplace=True)
        self.avd_first = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and avd_first else None
@@ -88,20 +88,17 @@ class ResNestBottleneck(nn.Module):
                group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation,
                dilation=first_dilation, groups=cardinality, radix=radix, norm_layer=norm_layer, drop_block=drop_block)
            self.bn2 = None  # FIXME revisit, here to satisfy current torchscript fussiness
            self.act2 = None
        else:
            self.conv2 = nn.Conv2d(
                group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation,
                dilation=first_dilation, groups=cardinality, bias=False)
            self.bn2 = norm_layer(group_width)
            self.act2 = act_layer(inplace=True)
        self.avd_last = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and not avd_first else None

        self.conv3 = nn.Conv2d(group_width, planes * 4, kernel_size=1, bias=False)
        self.bn3 = norm_layer(planes * 4)
        self.act3 = act_layer(inplace=True)
        self.downsample = downsample
@@ -113,8 +110,8 @@ class ResNestBottleneck(nn.Module):
        out = self.conv1(x)
        out = self.bn1(out)
        if self.drop_block is not None:
            out = self.drop_block(out)
        out = self.act1(out)

        if self.avd_first is not None:
@@ -123,8 +120,8 @@ class ResNestBottleneck(nn.Module):
        out = self.conv2(out)
        if self.bn2 is not None:
            out = self.bn2(out)
            if self.drop_block is not None:
                out = self.drop_block(out)
            out = self.act2(out)

        if self.avd_last is not None:
@@ -132,8 +129,8 @@ class ResNestBottleneck(nn.Module):
        out = self.conv3(out)
        out = self.bn3(out)
        if self.drop_block is not None:
            out = self.drop_block(out)

        if self.downsample is not None:
            residual = self.downsample(x)
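
A small aside on the consolidation above: because DropBlock2d holds only hyperparameters and samples a fresh mask on every call, a single instance can safely be reused after bn1, bn2 and bn3 in place of the former drop_block1/2/3 attributes. An illustrative check (tensor sizes made up, not from the commit):

import torch

# One shared instance applied at several points draws an independent mask each time.
db = DropBlock2d(drop_prob=0.1, block_size=7, fast=True)
db.train()
a = db(torch.randn(2, 64, 56, 56))
b = db(torch.randn(2, 64, 56, 56))            # different random mask, same hyperparameters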
