Final blurpool2d cleanup and add resnetblur50 weights, match tresnet Downsample arg order to BlurPool2d for interop

pull/136/head
Ross Wightman 5 years ago
parent 9590f301a9
commit 2681a8d618

@ -31,9 +31,9 @@ def load_state_dict(checkpoint_path, use_ema=False):
raise FileNotFoundError() raise FileNotFoundError()
def load_checkpoint(model, checkpoint_path, use_ema=False): def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
state_dict = load_state_dict(checkpoint_path, use_ema) state_dict = load_state_dict(checkpoint_path, use_ema)
model.load_state_dict(state_dict) model.load_state_dict(state_dict, strict=strict)
def resume_checkpoint(model, checkpoint_path): def resume_checkpoint(model, checkpoint_path):

@ -18,4 +18,4 @@ from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .anti_aliasing import AntiAliasDownsampleLayer from .anti_aliasing import AntiAliasDownsampleLayer
from .space_to_depth import SpaceToDepthModule from .space_to_depth import SpaceToDepthModule
from .blurpool import BlurPool2d from .blur_pool import BlurPool2d

@ -5,12 +5,12 @@ import torch.nn.functional as F
class AntiAliasDownsampleLayer(nn.Module): class AntiAliasDownsampleLayer(nn.Module):
def __init__(self, no_jit: bool = False, filt_size: int = 3, stride: int = 2, channels: int = 0): def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2, no_jit: bool = False):
super(AntiAliasDownsampleLayer, self).__init__() super(AntiAliasDownsampleLayer, self).__init__()
if no_jit: if no_jit:
self.op = Downsample(filt_size, stride, channels) self.op = Downsample(channels, filt_size, stride)
else: else:
self.op = DownsampleJIT(filt_size, stride, channels) self.op = DownsampleJIT(channels, filt_size, stride)
# FIXME I should probably override _apply and clear DownsampleJIT filter cache for .cuda(), .half(), etc calls # FIXME I should probably override _apply and clear DownsampleJIT filter cache for .cuda(), .half(), etc calls
@ -20,10 +20,10 @@ class AntiAliasDownsampleLayer(nn.Module):
@torch.jit.script @torch.jit.script
class DownsampleJIT(object): class DownsampleJIT(object):
def __init__(self, filt_size: int = 3, stride: int = 2, channels: int = 0): def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2):
self.channels = channels
self.stride = stride self.stride = stride
self.filt_size = filt_size self.filt_size = filt_size
self.channels = channels
assert self.filt_size == 3 assert self.filt_size == 3
assert stride == 2 assert stride == 2
self.filt = {} # lazy init by device for DataParallel compat self.filt = {} # lazy init by device for DataParallel compat
@ -32,8 +32,7 @@ class DownsampleJIT(object):
filt = torch.tensor([1., 2., 1.], dtype=like.dtype, device=like.device) filt = torch.tensor([1., 2., 1.], dtype=like.dtype, device=like.device)
filt = filt[:, None] * filt[None, :] filt = filt[:, None] * filt[None, :]
filt = filt / torch.sum(filt) filt = filt / torch.sum(filt)
filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1)) return filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
return filt
def __call__(self, input: torch.Tensor): def __call__(self, input: torch.Tensor):
input_pad = F.pad(input, (1, 1, 1, 1), 'reflect') input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
@ -42,11 +41,11 @@ class DownsampleJIT(object):
class Downsample(nn.Module): class Downsample(nn.Module):
def __init__(self, filt_size=3, stride=2, channels=None): def __init__(self, channels=None, filt_size=3, stride=2):
super(Downsample, self).__init__() super(Downsample, self).__init__()
self.channels = channels
self.filt_size = filt_size self.filt_size = filt_size
self.stride = stride self.stride = stride
self.channels = channels
assert self.filt_size == 3 assert self.filt_size == 3
filt = torch.tensor([1., 2., 1.]) filt = torch.tensor([1., 2., 1.])

@ -0,0 +1,58 @@
"""
BlurPool layer inspired by
- Kornia's Max_BlurPool2d
- Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`
FIXME merge this impl with those in `anti_aliasing.py`
Hacked together by Chris Ha and Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict
from .padding import get_padding
class BlurPool2d(nn.Module):
r"""Creates a module that computes blurs and downsample a given feature map.
See :cite:`zhang2019shiftinvar` for more details.
Corresponds to the Downsample class, which does blurring and subsampling
Args:
channels = Number of input channels
filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5.
stride (int): downsampling filter stride
Returns:
torch.Tensor: the transformed tensor.
"""
filt: Dict[str, torch.Tensor]
def __init__(self, channels, filt_size=3, stride=2) -> None:
super(BlurPool2d, self).__init__()
assert filt_size > 1
self.channels = channels
self.filt_size = filt_size
self.stride = stride
pad_size = [get_padding(filt_size, stride, dilation=1)] * 4
self.padding = nn.ReflectionPad2d(pad_size)
self._coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs) # for torchscript compat
self.filt = {} # lazy init by device for DataParallel compat
def _create_filter(self, like: torch.Tensor):
blur_filter = (self._coeffs[:, None] * self._coeffs[None, :]).to(dtype=like.dtype, device=like.device)
return blur_filter[None, None, :, :].repeat(self.channels, 1, 1, 1)
def _apply(self, fn):
# override nn.Module _apply, reset filter cache if used
self.filt = {}
super(BlurPool2d, self)._apply(fn)
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
C = input_tensor.shape[1]
blur_filt = self.filt.get(str(input_tensor.device), self._create_filter(input_tensor))
return F.conv2d(
self.padding(input_tensor), blur_filt, stride=self.stride, groups=C)

@ -1,55 +0,0 @@
"""
BlurPool layer inspired by
- Kornia's Max_BlurPool2d
- Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`
Hacked together by Chris Ha and Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from .padding import get_padding
class BlurPool2d(nn.Module):
r"""Creates a module that computes blurs and downsample a given feature map.
See :cite:`zhang2019shiftinvar` for more details.
Corresponds to the Downsample class, which does blurring and subsampling
Args:
channels = Number of input channels
blur_filter_size (int): binomial filter size for blurring. currently supports 3 (default) and 5.
stride (int): downsampling filter stride
Shape:
Returns:
torch.Tensor: the transformed tensor.
Examples:
"""
def __init__(self, channels, blur_filter_size=3, stride=2) -> None:
super(BlurPool2d, self).__init__()
assert blur_filter_size > 1
self.channels = channels
self.blur_filter_size = blur_filter_size
self.stride = stride
pad_size = [get_padding(blur_filter_size, stride, dilation=1)] * 4
self.padding = nn.ReflectionPad2d(pad_size)
blur_matrix = (np.poly1d((0.5, 0.5)) ** (blur_filter_size - 1)).coeffs
blur_filter = torch.Tensor(blur_matrix[:, None] * blur_matrix[None, :])
self.blur_filter = blur_filter[None, None, :, :]
def _apply(self, fn):
# override nn.Module _apply to prevent need for blur_filter to be registered as a buffer,
# this keeps it out of state dict, but allows .cuda(), .type(), etc to work as expected
super(BlurPool2d, self)._apply(fn)
self.blur_filter = fn(self.blur_filter)
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # type: ignore
C = input_tensor.shape[1]
return F.conv2d(
self.padding(input_tensor),
self.blur_filter.type(input_tensor.dtype).expand(C, -1, -1, -1), stride=self.stride, groups=C)

@ -118,8 +118,11 @@ default_cfgs = {
'ecaresnet101d_pruned': _cfg( 'ecaresnet101d_pruned': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth', url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth',
interpolation='bicubic'), interpolation='bicubic'),
'resnetblur18': _cfg(), 'resnetblur18': _cfg(
'resnetblur50': _cfg() interpolation='bicubic'),
'resnetblur50': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnetblur50-84f4748f.pth',
interpolation='bicubic')
} }
@ -133,7 +136,7 @@ class BasicBlock(nn.Module):
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
attn_layer=None, drop_block=None, drop_path=None, blur=False): attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
super(BasicBlock, self).__init__() super(BasicBlock, self).__init__()
assert cardinality == 1, 'BasicBlock only supports cardinality of 1' assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
@ -141,13 +144,14 @@ class BasicBlock(nn.Module):
first_planes = planes // reduce_first first_planes = planes // reduce_first
outplanes = planes * self.expansion outplanes = planes * self.expansion
first_dilation = first_dilation or dilation first_dilation = first_dilation or dilation
use_aa = aa_layer is not None
self.conv1 = nn.Conv2d( self.conv1 = nn.Conv2d(
inplanes, first_planes, kernel_size=3, stride=1 if blur else stride, padding=first_dilation, inplanes, first_planes, kernel_size=3, stride=1 if use_aa else stride, padding=first_dilation,
dilation=first_dilation, bias=False) dilation=first_dilation, bias=False)
self.bn1 = norm_layer(first_planes) self.bn1 = norm_layer(first_planes)
self.act1 = act_layer(inplace=True) self.act1 = act_layer(inplace=True)
self.blurpool = BlurPool2d(channels=first_planes) if stride == 2 and blur else None self.aa = aa_layer(channels=first_planes) if stride == 2 and use_aa else None
self.conv2 = nn.Conv2d( self.conv2 = nn.Conv2d(
first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False) first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False)
@ -173,8 +177,8 @@ class BasicBlock(nn.Module):
if self.drop_block is not None: if self.drop_block is not None:
x = self.drop_block(x) x = self.drop_block(x)
x = self.act1(x) x = self.act1(x)
if self.blurpool is not None: if self.aa is not None:
x = self.blurpool(x) x = self.aa(x)
x = self.conv2(x) x = self.conv2(x)
x = self.bn2(x) x = self.bn2(x)
@ -201,25 +205,25 @@ class Bottleneck(nn.Module):
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
attn_layer=None, drop_block=None, drop_path=None, blur=False): attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
super(Bottleneck, self).__init__() super(Bottleneck, self).__init__()
width = int(math.floor(planes * (base_width / 64)) * cardinality) width = int(math.floor(planes * (base_width / 64)) * cardinality)
first_planes = width // reduce_first first_planes = width // reduce_first
outplanes = planes * self.expansion outplanes = planes * self.expansion
first_dilation = first_dilation or dilation first_dilation = first_dilation or dilation
self.blur = blur use_aa = aa_layer is not None
self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False) self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
self.bn1 = norm_layer(first_planes) self.bn1 = norm_layer(first_planes)
self.act1 = act_layer(inplace=True) self.act1 = act_layer(inplace=True)
self.conv2 = nn.Conv2d( self.conv2 = nn.Conv2d(
first_planes, width, kernel_size=3, stride=1 if blur else stride, first_planes, width, kernel_size=3, stride=1 if use_aa else stride,
padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False) padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
self.bn2 = norm_layer(width) self.bn2 = norm_layer(width)
self.act2 = act_layer(inplace=True) self.act2 = act_layer(inplace=True)
self.blurpool = BlurPool2d(channels=width) if stride == 2 and blur else None self.aa = aa_layer(channels=width) if stride == 2 and use_aa else None
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
self.bn3 = norm_layer(outplanes) self.bn3 = norm_layer(outplanes)
@ -250,8 +254,8 @@ class Bottleneck(nn.Module):
if self.drop_block is not None: if self.drop_block is not None:
x = self.drop_block(x) x = self.drop_block(x)
x = self.act2(x) x = self.act2(x)
if self.blurpool is not None: if self.aa is not None:
x = self.blurpool(x) x = self.aa(x)
x = self.conv3(x) x = self.conv3(x)
x = self.bn3(x) x = self.bn3(x)
@ -365,25 +369,19 @@ class ResNet(nn.Module):
Whether to use average pooling for projection skip connection between stages/downsample. Whether to use average pooling for projection skip connection between stages/downsample.
output_stride : int, default 32 output_stride : int, default 32
Set the output stride of the network, 32, 16, or 8. Typically used in segmentation. Set the output stride of the network, 32, 16, or 8. Typically used in segmentation.
act_layer : class, activation layer act_layer : nn.Module, activation layer
norm_layer : class, normalization layer norm_layer : nn.Module, normalization layer
aa_layer : nn.Module, anti-aliasing layer
drop_rate : float, default 0. drop_rate : float, default 0.
Dropout probability before classifier, for training Dropout probability before classifier, for training
global_pool : str, default 'avg' global_pool : str, default 'avg'
Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
blur : str, default ''
Location of Blurring:
* '', default - Not applied
* 'max' - only stem layer MaxPool will be blurred
* 'strided' - only strided convolutions in the downsampling blocks (assembled-cnn style)
* 'max_strided' - on both stem MaxPool and strided convolutions (zhang2019shiftinvar style for ResNets)
""" """
def __init__(self, block, layers, num_classes=1000, in_chans=3, def __init__(self, block, layers, num_classes=1000, in_chans=3,
cardinality=1, base_width=64, stem_width=64, stem_type='', cardinality=1, base_width=64, stem_width=64, stem_type='',
block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32, block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, drop_path_rate=0., act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0.0, drop_path_rate=0.,
drop_block_rate=0., global_pool='avg', blur='', zero_init_last_bn=True, block_args=None): drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None):
block_args = block_args or dict() block_args = block_args or dict()
self.num_classes = num_classes self.num_classes = num_classes
deep_stem = 'deep' in stem_type deep_stem = 'deep' in stem_type
@ -392,7 +390,6 @@ class ResNet(nn.Module):
self.base_width = base_width self.base_width = base_width
self.drop_rate = drop_rate self.drop_rate = drop_rate
self.expansion = block.expansion self.expansion = block.expansion
self.blur = 'strided' in blur
super(ResNet, self).__init__() super(ResNet, self).__init__()
# Stem # Stem
@ -414,12 +411,12 @@ class ResNet(nn.Module):
self.bn1 = norm_layer(self.inplanes) self.bn1 = norm_layer(self.inplanes)
self.act1 = act_layer(inplace=True) self.act1 = act_layer(inplace=True)
# Stem Pooling # Stem Pooling
if 'max' in blur : if aa_layer is not None:
self.maxpool = nn.Sequential(*[ self.maxpool = nn.Sequential(*[
nn.MaxPool2d(kernel_size=3, stride=1, padding=1), nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
BlurPool2d(channels=self.inplanes, stride=2) aa_layer(channels=self.inplanes, stride=2)
]) ])
else : else:
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# Feature Blocks # Feature Blocks
@ -437,7 +434,7 @@ class ResNet(nn.Module):
assert output_stride == 32 assert output_stride == 32
layer_args = list(zip(channels, layers, strides, dilations)) layer_args = list(zip(channels, layers, strides, dilations))
layer_kwargs = dict( layer_kwargs = dict(
reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer, reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer,
avg_down=avg_down, down_kernel_size=down_kernel_size, drop_path=dp, **block_args) avg_down=avg_down, down_kernel_size=down_kernel_size, drop_path=dp, **block_args)
self.layer1 = self._make_layer(block, *layer_args[0], **layer_kwargs) self.layer1 = self._make_layer(block, *layer_args[0], **layer_kwargs)
self.layer2 = self._make_layer(block, *layer_args[1], **layer_kwargs) self.layer2 = self._make_layer(block, *layer_args[1], **layer_kwargs)
@ -472,7 +469,7 @@ class ResNet(nn.Module):
block_kwargs = dict( block_kwargs = dict(
cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first, cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
dilation=dilation, blur=self.blur, **kwargs) dilation=dilation, **kwargs)
layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **block_kwargs)] layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **block_kwargs)]
self.inplanes = planes * block.expansion self.inplanes = planes * block.expansion
layers += [block(self.inplanes, planes, **block_kwargs) for _ in range(1, blocks)] layers += [block(self.inplanes, planes, **block_kwargs) for _ in range(1, blocks)]
@ -1148,18 +1145,21 @@ def resnetblur18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNet-18 model with blur anti-aliasing """Constructs a ResNet-18 model with blur anti-aliasing
""" """
default_cfg = default_cfgs['resnetblur18'] default_cfg = default_cfgs['resnetblur18']
model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, blur='max_strided',**kwargs) model = ResNet(
BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, aa_layer=BlurPool2d, **kwargs)
model.default_cfg = default_cfg model.default_cfg = default_cfg
if pretrained: if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans) load_pretrained(model, default_cfg, num_classes, in_chans)
return model return model
@register_model @register_model
def resnetblur50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): def resnetblur50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNet-50 model with blur anti-aliasing """Constructs a ResNet-50 model with blur anti-aliasing
""" """
default_cfg = default_cfgs['resnetblur50'] default_cfg = default_cfgs['resnetblur50']
model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, blur='max_strided', **kwargs) model = ResNet(
Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, aa_layer=BlurPool2d, **kwargs)
model.default_cfg = default_cfg model.default_cfg = default_cfg
if pretrained: if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans) load_pretrained(model, default_cfg, num_classes, in_chans)

Loading…
Cancel
Save