Some cleanup and fixes for initial BlurPool impl. Still some testing and tweaks to go...

pull/101/head
Ross Wightman 5 years ago committed by Chris Ha
parent acd1b6cccd
commit 6cdeca24a3

@ -1,14 +1,17 @@
'''
"""
BlurPool layer inspired by
Kornia's Max_BlurPool2d
and
Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`
- Kornia's Max_BlurPool2d
- Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`
Hacked together by Chris Ha and Ross Wightman
"""
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from .padding import get_padding
class BlurPool2d(nn.Module):
@ -25,30 +28,30 @@ class BlurPool2d(nn.Module):
Examples:
"""
def __init__(self, channels=None, blur_filter_size=3, stride=2) -> None:
def __init__(self, channels, blur_filter_size=3, stride=2) -> None:
super(BlurPool2d, self).__init__()
assert blur_filter_size in [3, 5]
assert blur_filter_size > 1
self.channels = channels
self.blur_filter_size = blur_filter_size
self.stride = stride
if blur_filter_size == 3:
pad_size = [1] * 4
blur_matrix = torch.Tensor([[1., 2., 1]]) / 4 # binomial filter b2
else:
pad_size = [2] * 4
blur_matrix = torch.Tensor([[1., 4., 6., 4., 1.]]) / 16 # binomial filter b4
pad_size = [get_padding(blur_filter_size, stride, dilation=1)] * 4
self.padding = nn.ReflectionPad2d(pad_size)
blur_filter = blur_matrix * blur_matrix.T
blur_matrix = (np.poly1d((0.5, 0.5)) ** (blur_filter_size - 1)).coeffs
blur_filter = torch.Tensor(blur_matrix[:, None] * blur_matrix[None, :])
# FIXME figure a clean hack to prevent the filter from getting saved in weights, but still
# plays nice with recursive module apply for fn like .cuda(), .type(), etc -RW
self.register_buffer('blur_filter', blur_filter[None, None, :, :].repeat((self.channels, 1, 1, 1)))
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # type: ignore
if not torch.is_tensor(input_tensor):
raise TypeError("Input input type is not a torch.Tensor. Got {}"
.format(type(input_tensor)))
raise TypeError("Input input type is not a torch.Tensor. Got {}".format(type(input_tensor)))
if not len(input_tensor.shape) == 4:
raise ValueError("Invalid input shape, we expect BxCxHxW. Got: {}"
.format(input_tensor.shape))
raise ValueError("Invalid input shape, we expect BxCxHxW. Got: {}".format(input_tensor.shape))
# apply blur_filter on input
return F.conv2d(self.padding(input_tensor), self.blur_filter, stride=self.stride, groups=input_tensor.shape[1])
return F.conv2d(
self.padding(input_tensor),
self.blur_filter.type(input_tensor.dtype),
stride=self.stride,
groups=input_tensor.shape[1])

@ -127,21 +127,14 @@ class BasicBlock(nn.Module):
first_planes = planes // reduce_first
outplanes = planes * self.expansion
first_dilation = first_dilation or dilation
self.blur = blur
if blur and stride==2:
self.conv1 = nn.Conv2d(
inplanes, first_planes, kernel_size=3, stride=1, padding=first_dilation,
dilation=first_dilation, bias=False)
self.blurpool=BlurPool2d(channels=first_planes)
else:
self.conv1 = nn.Conv2d(
inplanes, first_planes, kernel_size=3, stride=stride, padding=first_dilation,
self.conv1 = nn.Conv2d(
inplanes, first_planes, kernel_size=3, stride=1 if blur else stride, padding=first_dilation,
dilation=first_dilation, bias=False)
self.blurpool = None
self.bn1 = norm_layer(first_planes)
self.act1 = act_layer(inplace=True)
self.blurpool = BlurPool2d(channels=first_planes) if stride == 2 and blur else None
self.conv2 = nn.Conv2d(
first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False)
self.bn2 = norm_layer(outplanes)
@ -165,11 +158,9 @@ class BasicBlock(nn.Module):
x = self.bn1(x)
if self.drop_block is not None:
x = self.drop_block(x)
x = self.act1(x)
if self.blurpool is not None:
x = self.act1(x)
x = self.blurpool(x)
else:
x = self.act1(x)
x = self.conv2(x)
x = self.bn2(x)
@ -209,19 +200,13 @@ class Bottleneck(nn.Module):
self.bn1 = norm_layer(first_planes)
self.act1 = act_layer(inplace=True)
if blur and stride==2:
self.conv2 = nn.Conv2d(
first_planes, width, kernel_size=3, stride=1,
padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
self.blurpool = BlurPool2d(channels=width)
else:
self.conv2 = nn.Conv2d(
first_planes, width, kernel_size=3, stride=stride,
padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
self.blurpool = None
self.conv2 = nn.Conv2d(
first_planes, width, kernel_size=3, stride=1 if blur else stride,
padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
self.bn2 = norm_layer(width)
self.act2 = act_layer(inplace=True)
self.blurpool = BlurPool2d(channels=width) if stride == 2 and blur else None
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
self.bn3 = norm_layer(outplanes)
@ -251,6 +236,8 @@ class Bottleneck(nn.Module):
if self.drop_block is not None:
x = self.drop_block(x)
x = self.act2(x)
if self.blurpool is not None:
x = self.blurpool(x)
x = self.conv3(x)
x = self.bn3(x)
@ -412,11 +399,12 @@ class ResNet(nn.Module):
self.conv1 = nn.Conv2d(in_chans, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = norm_layer(self.inplanes)
self.act1 = act_layer(inplace=True)
# Stem Blur
# Stem Pooling
if 'max' in blur :
self.maxpool = nn.Sequential(*[
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
BlurPool2d(channels=self.inplanes)])
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
BlurPool2d(channels=self.inplanes, stride=2)
])
else :
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
@ -470,8 +458,8 @@ class ResNet(nn.Module):
block_kwargs = dict(
cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
dilation=dilation, **kwargs)
layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, blur=self.blur, **block_kwargs)]
dilation=dilation, blur=self.blur, **kwargs)
layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **block_kwargs)]
self.inplanes = planes * block.expansion
layers += [block(self.inplanes, planes, **block_kwargs) for _ in range(1, blocks)]
@ -1075,7 +1063,7 @@ def resnetblur18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnetblur50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNet-50 model. With assembled-cnn style blur
"""
default_cfg = default_cfgs['resnetblur18']
model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, blur='strided', **kwargs)
default_cfg = default_cfgs['resnetblur50']
model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, blur='max_strided', **kwargs)
model.default_cfg = default_cfg
return model
Loading…
Cancel
Save