From 6cdeca24a389cce7693e043b89db04e230b3a405 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 1 Apr 2020 12:01:59 -0700
Subject: [PATCH] Some cleanup and fixes for initial BlurPool impl. Still some
 testing and tweaks to go...

---
 timm/models/layers/blurpool.py | 43 +++++++++++++++-------------
 timm/models/resnet.py          | 52 +++++++++++++---------------------
 2 files changed, 43 insertions(+), 52 deletions(-)

diff --git a/timm/models/layers/blurpool.py b/timm/models/layers/blurpool.py
index 0b37a90c..a12274d8 100644
--- a/timm/models/layers/blurpool.py
+++ b/timm/models/layers/blurpool.py
@@ -1,14 +1,17 @@
-'''
+"""
 BlurPool layer inspired by
-Kornia's Max_BlurPool2d
-and
-Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`
+ - Kornia's Max_BlurPool2d
+ - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`
+
+Hacked together by Chris Ha and Ross Wightman
+"""
 
-'''
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import numpy as np
+from .padding import get_padding
 
 
 class BlurPool2d(nn.Module):
@@ -25,30 +28,30 @@ class BlurPool2d(nn.Module):
     Examples:
     """
 
-    def __init__(self, channels=None, blur_filter_size=3, stride=2) -> None:
+    def __init__(self, channels, blur_filter_size=3, stride=2) -> None:
         super(BlurPool2d, self).__init__()
-        assert blur_filter_size in [3, 5]
+        assert blur_filter_size > 1
         self.channels = channels
         self.blur_filter_size = blur_filter_size
         self.stride = stride
-        if blur_filter_size == 3:
-            pad_size = [1] * 4
-            blur_matrix = torch.Tensor([[1., 2., 1]]) / 4  # binomial filter b2
-        else:
-            pad_size = [2] * 4
-            blur_matrix = torch.Tensor([[1., 4., 6., 4., 1.]]) / 16  # binomial filter b4
-
+        pad_size = [get_padding(blur_filter_size, stride, dilation=1)] * 4
         self.padding = nn.ReflectionPad2d(pad_size)
-        blur_filter = blur_matrix * blur_matrix.T
+
+        blur_matrix = (np.poly1d((0.5, 0.5)) ** (blur_filter_size - 1)).coeffs
+        blur_filter = torch.Tensor(blur_matrix[:, None] * blur_matrix[None, :])
+        # FIXME figure a clean hack to prevent the filter from getting saved in weights, but still
+        # plays nice with recursive module apply for fn like .cuda(), .type(), etc -RW
         self.register_buffer('blur_filter', blur_filter[None, None, :, :].repeat((self.channels, 1, 1, 1)))
 
     def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:  # type: ignore
         if not torch.is_tensor(input_tensor):
-            raise TypeError("Input input type is not a torch.Tensor. Got {}"
-                            .format(type(input_tensor)))
+            raise TypeError("Input input type is not a torch.Tensor. Got {}".format(type(input_tensor)))
         if not len(input_tensor.shape) == 4:
-            raise ValueError("Invalid input shape, we expect BxCxHxW. Got: {}"
-                             .format(input_tensor.shape))
+            raise ValueError("Invalid input shape, we expect BxCxHxW. Got: {}".format(input_tensor.shape))
         # apply blur_filter on input
-        return F.conv2d(self.padding(input_tensor), self.blur_filter, stride=self.stride, groups=input_tensor.shape[1])
\ No newline at end of file
+        return F.conv2d(
+            self.padding(input_tensor),
+            self.blur_filter.type(input_tensor.dtype),
+            stride=self.stride,
+            groups=input_tensor.shape[1])
diff --git a/timm/models/resnet.py b/timm/models/resnet.py
index 057eca6c..fdb8097c 100644
--- a/timm/models/resnet.py
+++ b/timm/models/resnet.py
@@ -127,21 +127,14 @@ class BasicBlock(nn.Module):
         first_planes = planes // reduce_first
         outplanes = planes * self.expansion
         first_dilation = first_dilation or dilation
-        self.blur = blur
-        if blur and stride==2:
-            self.conv1 = nn.Conv2d(
-                inplanes, first_planes, kernel_size=3, stride=1, padding=first_dilation,
-                dilation=first_dilation, bias=False)
-            self.blurpool=BlurPool2d(channels=first_planes)
-        else:
-            self.conv1 = nn.Conv2d(
-                inplanes, first_planes, kernel_size=3, stride=stride, padding=first_dilation,
+        self.conv1 = nn.Conv2d(
+            inplanes, first_planes, kernel_size=3, stride=1 if blur else stride, padding=first_dilation,
             dilation=first_dilation, bias=False)
-            self.blurpool = None
-
         self.bn1 = norm_layer(first_planes)
         self.act1 = act_layer(inplace=True)
+        self.blurpool = BlurPool2d(channels=first_planes) if stride == 2 and blur else None
+
         self.conv2 = nn.Conv2d(
             first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False)
         self.bn2 = norm_layer(outplanes)
@@ -165,11 +158,9 @@ class BasicBlock(nn.Module):
         x = self.bn1(x)
         if self.drop_block is not None:
             x = self.drop_block(x)
+        x = self.act1(x)
         if self.blurpool is not None:
-            x = self.act1(x)
             x = self.blurpool(x)
-        else:
-            x = self.act1(x)
 
         x = self.conv2(x)
         x = self.bn2(x)
@@ -209,19 +200,13 @@ class Bottleneck(nn.Module):
 
         self.bn1 = norm_layer(first_planes)
         self.act1 = act_layer(inplace=True)
-        if blur and stride==2:
-            self.conv2 = nn.Conv2d(
-                first_planes, width, kernel_size=3, stride=1,
-                padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
-            self.blurpool = BlurPool2d(channels=width)
-        else:
-            self.conv2 = nn.Conv2d(
-                first_planes, width, kernel_size=3, stride=stride,
-                padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
-            self.blurpool = None
-
+        self.conv2 = nn.Conv2d(
+            first_planes, width, kernel_size=3, stride=1 if blur else stride,
+            padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
         self.bn2 = norm_layer(width)
         self.act2 = act_layer(inplace=True)
+        self.blurpool = BlurPool2d(channels=width) if stride == 2 and blur else None
+
         self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
         self.bn3 = norm_layer(outplanes)
@@ -251,6 +236,8 @@ class Bottleneck(nn.Module):
         if self.drop_block is not None:
             x = self.drop_block(x)
         x = self.act2(x)
+        if self.blurpool is not None:
+            x = self.blurpool(x)
 
         x = self.conv3(x)
         x = self.bn3(x)
@@ -412,11 +399,12 @@ class ResNet(nn.Module):
         self.conv1 = nn.Conv2d(in_chans, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
         self.bn1 = norm_layer(self.inplanes)
         self.act1 = act_layer(inplace=True)
-        # Stem Blur
+        # Stem Pooling
         if 'max' in blur :
             self.maxpool = nn.Sequential(*[
-                nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
-                BlurPool2d(channels=self.inplanes)])
+                nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
+                BlurPool2d(channels=self.inplanes, stride=2)
+            ])
         else :
             self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
@@ -470,8 +458,8 @@ class ResNet(nn.Module):
         block_kwargs = dict(
             cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
-            dilation=dilation, **kwargs)
-        layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, blur=self.blur, **block_kwargs)]
+            dilation=dilation, blur=self.blur, **kwargs)
+        layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **block_kwargs)]
         self.inplanes = planes * block.expansion
         layers += [block(self.inplanes, planes, **block_kwargs) for _ in range(1, blocks)]
@@ -1075,7 +1063,7 @@ def resnetblur18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
 def resnetblur50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     """Constructs a ResNet-50 model. With assembled-cnn style blur
     """
-    default_cfg = default_cfgs['resnetblur18']
-    model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, blur='strided', **kwargs)
+    default_cfg = default_cfgs['resnetblur50']
+    model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, blur='max_strided', **kwargs)
     model.default_cfg = default_cfg
     return model
\ No newline at end of file
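
Note (editorial, not part of the patch): a minimal sanity-check sketch of the reworked layer. It confirms that the new np.poly1d((0.5, 0.5)) ** (blur_filter_size - 1) expansion reproduces the previously hard-coded b2/b4 binomial filters, and that a stride-1 3x3 conv followed by BlurPool2d at stride 2 (the anti-aliased path the blocks take after this patch) downsamples to the same spatial size as the strided conv it replaces. The module path is taken from the diff above and is assumed importable at this commit.

import numpy as np
import torch
import torch.nn as nn

# Filter construction from this patch: expanding (0.5 + 0.5x)^(n-1) with np.poly1d
# yields the normalized binomial blur kernel for any blur_filter_size n > 1.
for size in (3, 5):
    coeffs = (np.poly1d((0.5, 0.5)) ** (size - 1)).coeffs
    print(size, coeffs)
# size 3 -> [0.25 0.5 0.25], i.e. the old hard-coded [1., 2., 1.] / 4
# size 5 -> [0.0625 0.25 0.375 0.25 0.0625], i.e. the old hard-coded [1., 4., 6., 4., 1.] / 16

# Module path taken from the diff above (assumed importable at this commit).
from timm.models.layers.blurpool import BlurPool2d

x = torch.randn(2, 64, 56, 56)

# Plain strided 3x3 conv, the path the blocks take when blur is disabled...
strided = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
# ...vs. the anti-aliased path after this patch: stride-1 conv, then BlurPool2d at stride 2.
blurred = nn.Sequential(
    nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
    BlurPool2d(channels=64, blur_filter_size=3, stride=2),
)

print(strided(x).shape)   # torch.Size([2, 64, 28, 28])
print(blurred(x).shape)   # torch.Size([2, 64, 28, 28]), same downsample, but blurred before subsampling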