Final blurpool2d cleanup and add resnetblur50 weights, match tresnet Downsample arg order to BlurPool2d for interop

pull/136/head
Ross Wightman 5 years ago
parent 9590f301a9
commit 2681a8d618

@ -31,9 +31,9 @@ def load_state_dict(checkpoint_path, use_ema=False):
raise FileNotFoundError()
def load_checkpoint(model, checkpoint_path, use_ema=False):
def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
state_dict = load_state_dict(checkpoint_path, use_ema)
model.load_state_dict(state_dict)
model.load_state_dict(state_dict, strict=strict)
def resume_checkpoint(model, checkpoint_path):

@ -18,4 +18,4 @@ from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .anti_aliasing import AntiAliasDownsampleLayer
from .space_to_depth import SpaceToDepthModule
from .blurpool import BlurPool2d
from .blur_pool import BlurPool2d

@ -5,12 +5,12 @@ import torch.nn.functional as F
class AntiAliasDownsampleLayer(nn.Module):
def __init__(self, no_jit: bool = False, filt_size: int = 3, stride: int = 2, channels: int = 0):
def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2, no_jit: bool = False):
super(AntiAliasDownsampleLayer, self).__init__()
if no_jit:
self.op = Downsample(filt_size, stride, channels)
self.op = Downsample(channels, filt_size, stride)
else:
self.op = DownsampleJIT(filt_size, stride, channels)
self.op = DownsampleJIT(channels, filt_size, stride)
# FIXME I should probably override _apply and clear DownsampleJIT filter cache for .cuda(), .half(), etc calls
@ -20,10 +20,10 @@ class AntiAliasDownsampleLayer(nn.Module):
@torch.jit.script
class DownsampleJIT(object):
def __init__(self, filt_size: int = 3, stride: int = 2, channels: int = 0):
def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2):
self.channels = channels
self.stride = stride
self.filt_size = filt_size
self.channels = channels
assert self.filt_size == 3
assert stride == 2
self.filt = {} # lazy init by device for DataParallel compat
@ -32,8 +32,7 @@ class DownsampleJIT(object):
filt = torch.tensor([1., 2., 1.], dtype=like.dtype, device=like.device)
filt = filt[:, None] * filt[None, :]
filt = filt / torch.sum(filt)
filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
return filt
return filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
def __call__(self, input: torch.Tensor):
input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
@ -42,11 +41,11 @@ class DownsampleJIT(object):
class Downsample(nn.Module):
def __init__(self, filt_size=3, stride=2, channels=None):
def __init__(self, channels=None, filt_size=3, stride=2):
super(Downsample, self).__init__()
self.channels = channels
self.filt_size = filt_size
self.stride = stride
self.channels = channels
assert self.filt_size == 3
filt = torch.tensor([1., 2., 1.])

@ -0,0 +1,58 @@
"""
BlurPool layer inspired by
- Kornia's Max_BlurPool2d
- Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`
FIXME merge this impl with those in `anti_aliasing.py`
Hacked together by Chris Ha and Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict
from .padding import get_padding
class BlurPool2d(nn.Module):
r"""Creates a module that computes blurs and downsample a given feature map.
See :cite:`zhang2019shiftinvar` for more details.
Corresponds to the Downsample class, which does blurring and subsampling
Args:
channels = Number of input channels
filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5.
stride (int): downsampling filter stride
Returns:
torch.Tensor: the transformed tensor.
"""
filt: Dict[str, torch.Tensor]
def __init__(self, channels, filt_size=3, stride=2) -> None:
super(BlurPool2d, self).__init__()
assert filt_size > 1
self.channels = channels
self.filt_size = filt_size
self.stride = stride
pad_size = [get_padding(filt_size, stride, dilation=1)] * 4
self.padding = nn.ReflectionPad2d(pad_size)
self._coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs) # for torchscript compat
self.filt = {} # lazy init by device for DataParallel compat
def _create_filter(self, like: torch.Tensor):
blur_filter = (self._coeffs[:, None] * self._coeffs[None, :]).to(dtype=like.dtype, device=like.device)
return blur_filter[None, None, :, :].repeat(self.channels, 1, 1, 1)
def _apply(self, fn):
# override nn.Module _apply, reset filter cache if used
self.filt = {}
super(BlurPool2d, self)._apply(fn)
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
C = input_tensor.shape[1]
blur_filt = self.filt.get(str(input_tensor.device), self._create_filter(input_tensor))
return F.conv2d(
self.padding(input_tensor), blur_filt, stride=self.stride, groups=C)

@ -1,55 +0,0 @@
"""
BlurPool layer inspired by
- Kornia's Max_BlurPool2d
- Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`
Hacked together by Chris Ha and Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from .padding import get_padding
class BlurPool2d(nn.Module):
r"""Creates a module that computes blurs and downsample a given feature map.
See :cite:`zhang2019shiftinvar` for more details.
Corresponds to the Downsample class, which does blurring and subsampling
Args:
channels = Number of input channels
blur_filter_size (int): binomial filter size for blurring. currently supports 3 (default) and 5.
stride (int): downsampling filter stride
Shape:
Returns:
torch.Tensor: the transformed tensor.
Examples:
"""
def __init__(self, channels, blur_filter_size=3, stride=2) -> None:
super(BlurPool2d, self).__init__()
assert blur_filter_size > 1
self.channels = channels
self.blur_filter_size = blur_filter_size
self.stride = stride
pad_size = [get_padding(blur_filter_size, stride, dilation=1)] * 4
self.padding = nn.ReflectionPad2d(pad_size)
blur_matrix = (np.poly1d((0.5, 0.5)) ** (blur_filter_size - 1)).coeffs
blur_filter = torch.Tensor(blur_matrix[:, None] * blur_matrix[None, :])
self.blur_filter = blur_filter[None, None, :, :]
def _apply(self, fn):
# override nn.Module _apply to prevent need for blur_filter to be registered as a buffer,
# this keeps it out of state dict, but allows .cuda(), .type(), etc to work as expected
super(BlurPool2d, self)._apply(fn)
self.blur_filter = fn(self.blur_filter)
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # type: ignore
C = input_tensor.shape[1]
return F.conv2d(
self.padding(input_tensor),
self.blur_filter.type(input_tensor.dtype).expand(C, -1, -1, -1), stride=self.stride, groups=C)

@ -118,8 +118,11 @@ default_cfgs = {
'ecaresnet101d_pruned': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth',
interpolation='bicubic'),
'resnetblur18': _cfg(),
'resnetblur50': _cfg()
'resnetblur18': _cfg(
interpolation='bicubic'),
'resnetblur50': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnetblur50-84f4748f.pth',
interpolation='bicubic')
}
@ -133,7 +136,7 @@ class BasicBlock(nn.Module):
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
attn_layer=None, drop_block=None, drop_path=None, blur=False):
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
super(BasicBlock, self).__init__()
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
@ -141,13 +144,14 @@ class BasicBlock(nn.Module):
first_planes = planes // reduce_first
outplanes = planes * self.expansion
first_dilation = first_dilation or dilation
use_aa = aa_layer is not None
self.conv1 = nn.Conv2d(
inplanes, first_planes, kernel_size=3, stride=1 if blur else stride, padding=first_dilation,
inplanes, first_planes, kernel_size=3, stride=1 if use_aa else stride, padding=first_dilation,
dilation=first_dilation, bias=False)
self.bn1 = norm_layer(first_planes)
self.act1 = act_layer(inplace=True)
self.blurpool = BlurPool2d(channels=first_planes) if stride == 2 and blur else None
self.aa = aa_layer(channels=first_planes) if stride == 2 and use_aa else None
self.conv2 = nn.Conv2d(
first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False)
@ -173,8 +177,8 @@ class BasicBlock(nn.Module):
if self.drop_block is not None:
x = self.drop_block(x)
x = self.act1(x)
if self.blurpool is not None:
x = self.blurpool(x)
if self.aa is not None:
x = self.aa(x)
x = self.conv2(x)
x = self.bn2(x)
@ -201,25 +205,25 @@ class Bottleneck(nn.Module):
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
attn_layer=None, drop_block=None, drop_path=None, blur=False):
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
super(Bottleneck, self).__init__()
width = int(math.floor(planes * (base_width / 64)) * cardinality)
first_planes = width // reduce_first
outplanes = planes * self.expansion
first_dilation = first_dilation or dilation
self.blur = blur
use_aa = aa_layer is not None
self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
self.bn1 = norm_layer(first_planes)
self.act1 = act_layer(inplace=True)
self.conv2 = nn.Conv2d(
first_planes, width, kernel_size=3, stride=1 if blur else stride,
first_planes, width, kernel_size=3, stride=1 if use_aa else stride,
padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
self.bn2 = norm_layer(width)
self.act2 = act_layer(inplace=True)
self.blurpool = BlurPool2d(channels=width) if stride == 2 and blur else None
self.aa = aa_layer(channels=width) if stride == 2 and use_aa else None
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
self.bn3 = norm_layer(outplanes)
@ -250,8 +254,8 @@ class Bottleneck(nn.Module):
if self.drop_block is not None:
x = self.drop_block(x)
x = self.act2(x)
if self.blurpool is not None:
x = self.blurpool(x)
if self.aa is not None:
x = self.aa(x)
x = self.conv3(x)
x = self.bn3(x)
@ -365,25 +369,19 @@ class ResNet(nn.Module):
Whether to use average pooling for projection skip connection between stages/downsample.
output_stride : int, default 32
Set the output stride of the network, 32, 16, or 8. Typically used in segmentation.
act_layer : class, activation layer
norm_layer : class, normalization layer
act_layer : nn.Module, activation layer
norm_layer : nn.Module, normalization layer
aa_layer : nn.Module, anti-aliasing layer
drop_rate : float, default 0.
Dropout probability before classifier, for training
global_pool : str, default 'avg'
Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
blur : str, default ''
Location of Blurring:
* '', default - Not applied
* 'max' - only stem layer MaxPool will be blurred
* 'strided' - only strided convolutions in the downsampling blocks (assembled-cnn style)
* 'max_strided' - on both stem MaxPool and strided convolutions (zhang2019shiftinvar style for ResNets)
"""
def __init__(self, block, layers, num_classes=1000, in_chans=3,
cardinality=1, base_width=64, stem_width=64, stem_type='',
block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, drop_path_rate=0.,
drop_block_rate=0., global_pool='avg', blur='', zero_init_last_bn=True, block_args=None):
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0.0, drop_path_rate=0.,
drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None):
block_args = block_args or dict()
self.num_classes = num_classes
deep_stem = 'deep' in stem_type
@ -392,7 +390,6 @@ class ResNet(nn.Module):
self.base_width = base_width
self.drop_rate = drop_rate
self.expansion = block.expansion
self.blur = 'strided' in blur
super(ResNet, self).__init__()
# Stem
@ -414,10 +411,10 @@ class ResNet(nn.Module):
self.bn1 = norm_layer(self.inplanes)
self.act1 = act_layer(inplace=True)
# Stem Pooling
if 'max' in blur :
if aa_layer is not None:
self.maxpool = nn.Sequential(*[
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
BlurPool2d(channels=self.inplanes, stride=2)
aa_layer(channels=self.inplanes, stride=2)
])
else:
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
@ -437,7 +434,7 @@ class ResNet(nn.Module):
assert output_stride == 32
layer_args = list(zip(channels, layers, strides, dilations))
layer_kwargs = dict(
reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer,
reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer,
avg_down=avg_down, down_kernel_size=down_kernel_size, drop_path=dp, **block_args)
self.layer1 = self._make_layer(block, *layer_args[0], **layer_kwargs)
self.layer2 = self._make_layer(block, *layer_args[1], **layer_kwargs)
@ -472,7 +469,7 @@ class ResNet(nn.Module):
block_kwargs = dict(
cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
dilation=dilation, blur=self.blur, **kwargs)
dilation=dilation, **kwargs)
layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **block_kwargs)]
self.inplanes = planes * block.expansion
layers += [block(self.inplanes, planes, **block_kwargs) for _ in range(1, blocks)]
@ -1148,18 +1145,21 @@ def resnetblur18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNet-18 model with blur anti-aliasing
"""
default_cfg = default_cfgs['resnetblur18']
model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, blur='max_strided',**kwargs)
model = ResNet(
BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, aa_layer=BlurPool2d, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def resnetblur50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNet-50 model with blur anti-aliasing
"""
default_cfg = default_cfgs['resnetblur50']
model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, blur='max_strided', **kwargs)
model = ResNet(
Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, aa_layer=BlurPool2d, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)

Loading…
Cancel
Save