diff --git a/timm/models/layers/drop.py b/timm/models/layers/drop.py index 6de9e3f7..90c1933a 100644 --- a/timm/models/layers/drop.py +++ b/timm/models/layers/drop.py @@ -69,7 +69,7 @@ def drop_block_2d( def drop_block_fast_2d( x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7, - gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False, batchwise: bool = False): + gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False): """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid @@ -81,24 +81,19 @@ def drop_block_fast_2d( gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( (W - block_size + 1) * (H - block_size + 1)) - if batchwise: - # one mask for whole batch, quite a bit faster - block_mask = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) < gamma - else: - # mask per batch element - block_mask = torch.rand_like(x) < gamma + block_mask = torch.empty_like(x).bernoulli_(gamma) block_mask = F.max_pool2d( block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2) if with_noise: - normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x) + normal_noise = torch.empty_like(x).normal_() if inplace: x.mul_(1. - block_mask).add_(normal_noise * block_mask) else: x = x * (1. - block_mask) + normal_noise * block_mask else: block_mask = 1 - block_mask - normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(dtype=x.dtype) + normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-6)).to(dtype=x.dtype) if inplace: x.mul_(block_mask * normalize_scale) else: @@ -131,13 +126,13 @@ class DropBlock2d(nn.Module): return x if self.fast: return drop_block_fast_2d( - x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise) + x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace) else: return drop_block_2d( x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise) -def drop_path(x, drop_prob: float = 0., training: bool = False): +def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, @@ -151,18 +146,19 @@ def drop_path(x, drop_prob: float = 0., training: bool = False): return x keep_prob = 1 - drop_prob shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) - random_tensor.floor_() # binarize - output = x.div(keep_prob) * random_tensor - return output + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """ - def __init__(self, drop_prob=None): + def __init__(self, drop_prob=None, scale_by_keep=True): super(DropPath, self).__init__() self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep def forward(self, x): - return drop_path(x, self.drop_prob, self.training) + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)