Merge branch 'VRandme-blur'

pull/136/head
Ross Wightman 5 years ago
commit 28739bb721

@ -31,9 +31,9 @@ def load_state_dict(checkpoint_path, use_ema=False):
raise FileNotFoundError() raise FileNotFoundError()
def load_checkpoint(model, checkpoint_path, use_ema=False): def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
state_dict = load_state_dict(checkpoint_path, use_ema) state_dict = load_state_dict(checkpoint_path, use_ema)
model.load_state_dict(state_dict) model.load_state_dict(state_dict, strict=strict)
def resume_checkpoint(model, checkpoint_path): def resume_checkpoint(model, checkpoint_path):

@ -18,3 +18,4 @@ from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .anti_aliasing import AntiAliasDownsampleLayer from .anti_aliasing import AntiAliasDownsampleLayer
from .space_to_depth import SpaceToDepthModule from .space_to_depth import SpaceToDepthModule
from .blur_pool import BlurPool2d

@ -5,12 +5,12 @@ import torch.nn.functional as F
class AntiAliasDownsampleLayer(nn.Module): class AntiAliasDownsampleLayer(nn.Module):
def __init__(self, no_jit: bool = False, filt_size: int = 3, stride: int = 2, channels: int = 0): def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2, no_jit: bool = False):
super(AntiAliasDownsampleLayer, self).__init__() super(AntiAliasDownsampleLayer, self).__init__()
if no_jit: if no_jit:
self.op = Downsample(filt_size, stride, channels) self.op = Downsample(channels, filt_size, stride)
else: else:
self.op = DownsampleJIT(filt_size, stride, channels) self.op = DownsampleJIT(channels, filt_size, stride)
# FIXME I should probably override _apply and clear DownsampleJIT filter cache for .cuda(), .half(), etc calls # FIXME I should probably override _apply and clear DownsampleJIT filter cache for .cuda(), .half(), etc calls
@ -20,10 +20,10 @@ class AntiAliasDownsampleLayer(nn.Module):
@torch.jit.script @torch.jit.script
class DownsampleJIT(object): class DownsampleJIT(object):
def __init__(self, filt_size: int = 3, stride: int = 2, channels: int = 0): def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2):
self.channels = channels
self.stride = stride self.stride = stride
self.filt_size = filt_size self.filt_size = filt_size
self.channels = channels
assert self.filt_size == 3 assert self.filt_size == 3
assert stride == 2 assert stride == 2
self.filt = {} # lazy init by device for DataParallel compat self.filt = {} # lazy init by device for DataParallel compat
@ -32,8 +32,7 @@ class DownsampleJIT(object):
filt = torch.tensor([1., 2., 1.], dtype=like.dtype, device=like.device) filt = torch.tensor([1., 2., 1.], dtype=like.dtype, device=like.device)
filt = filt[:, None] * filt[None, :] filt = filt[:, None] * filt[None, :]
filt = filt / torch.sum(filt) filt = filt / torch.sum(filt)
filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1)) return filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
return filt
def __call__(self, input: torch.Tensor): def __call__(self, input: torch.Tensor):
input_pad = F.pad(input, (1, 1, 1, 1), 'reflect') input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
@ -42,11 +41,11 @@ class DownsampleJIT(object):
class Downsample(nn.Module): class Downsample(nn.Module):
def __init__(self, filt_size=3, stride=2, channels=None): def __init__(self, channels=None, filt_size=3, stride=2):
super(Downsample, self).__init__() super(Downsample, self).__init__()
self.channels = channels
self.filt_size = filt_size self.filt_size = filt_size
self.stride = stride self.stride = stride
self.channels = channels
assert self.filt_size == 3 assert self.filt_size == 3
filt = torch.tensor([1., 2., 1.]) filt = torch.tensor([1., 2., 1.])

@ -0,0 +1,58 @@
"""
BlurPool layer inspired by
- Kornia's Max_BlurPool2d
- Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`
FIXME merge this impl with those in `anti_aliasing.py`
Hacked together by Chris Ha and Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict
from .padding import get_padding
class BlurPool2d(nn.Module):
r"""Creates a module that computes blurs and downsample a given feature map.
See :cite:`zhang2019shiftinvar` for more details.
Corresponds to the Downsample class, which does blurring and subsampling
Args:
channels = Number of input channels
filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5.
stride (int): downsampling filter stride
Returns:
torch.Tensor: the transformed tensor.
"""
filt: Dict[str, torch.Tensor]
def __init__(self, channels, filt_size=3, stride=2) -> None:
super(BlurPool2d, self).__init__()
assert filt_size > 1
self.channels = channels
self.filt_size = filt_size
self.stride = stride
pad_size = [get_padding(filt_size, stride, dilation=1)] * 4
self.padding = nn.ReflectionPad2d(pad_size)
self._coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs) # for torchscript compat
self.filt = {} # lazy init by device for DataParallel compat
def _create_filter(self, like: torch.Tensor):
blur_filter = (self._coeffs[:, None] * self._coeffs[None, :]).to(dtype=like.dtype, device=like.device)
return blur_filter[None, None, :, :].repeat(self.channels, 1, 1, 1)
def _apply(self, fn):
# override nn.Module _apply, reset filter cache if used
self.filt = {}
super(BlurPool2d, self)._apply(fn)
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
C = input_tensor.shape[1]
blur_filt = self.filt.get(str(input_tensor.device), self._create_filter(input_tensor))
return F.conv2d(
self.padding(input_tensor), blur_filt, stride=self.stride, groups=C)

@ -12,7 +12,7 @@ import torch.nn.functional as F
from .registry import register_model from .registry import register_model
from .helpers import load_pretrained, adapt_model_from_file from .helpers import load_pretrained, adapt_model_from_file
from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn, BlurPool2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
__all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this __all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this
@ -118,6 +118,11 @@ default_cfgs = {
'ecaresnet101d_pruned': _cfg( 'ecaresnet101d_pruned': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth', url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth',
interpolation='bicubic'), interpolation='bicubic'),
'resnetblur18': _cfg(
interpolation='bicubic'),
'resnetblur50': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnetblur50-84f4748f.pth',
interpolation='bicubic')
} }
@ -131,7 +136,7 @@ class BasicBlock(nn.Module):
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
attn_layer=None, drop_block=None, drop_path=None): attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
super(BasicBlock, self).__init__() super(BasicBlock, self).__init__()
assert cardinality == 1, 'BasicBlock only supports cardinality of 1' assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
@ -139,12 +144,15 @@ class BasicBlock(nn.Module):
first_planes = planes // reduce_first first_planes = planes // reduce_first
outplanes = planes * self.expansion outplanes = planes * self.expansion
first_dilation = first_dilation or dilation first_dilation = first_dilation or dilation
use_aa = aa_layer is not None
self.conv1 = nn.Conv2d( self.conv1 = nn.Conv2d(
inplanes, first_planes, kernel_size=3, stride=stride, padding=first_dilation, inplanes, first_planes, kernel_size=3, stride=1 if use_aa else stride, padding=first_dilation,
dilation=first_dilation, bias=False) dilation=first_dilation, bias=False)
self.bn1 = norm_layer(first_planes) self.bn1 = norm_layer(first_planes)
self.act1 = act_layer(inplace=True) self.act1 = act_layer(inplace=True)
self.aa = aa_layer(channels=first_planes) if stride == 2 and use_aa else None
self.conv2 = nn.Conv2d( self.conv2 = nn.Conv2d(
first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False) first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False)
self.bn2 = norm_layer(outplanes) self.bn2 = norm_layer(outplanes)
@ -169,6 +177,8 @@ class BasicBlock(nn.Module):
if self.drop_block is not None: if self.drop_block is not None:
x = self.drop_block(x) x = self.drop_block(x)
x = self.act1(x) x = self.act1(x)
if self.aa is not None:
x = self.aa(x)
x = self.conv2(x) x = self.conv2(x)
x = self.bn2(x) x = self.bn2(x)
@ -195,22 +205,26 @@ class Bottleneck(nn.Module):
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
attn_layer=None, drop_block=None, drop_path=None): attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
super(Bottleneck, self).__init__() super(Bottleneck, self).__init__()
width = int(math.floor(planes * (base_width / 64)) * cardinality) width = int(math.floor(planes * (base_width / 64)) * cardinality)
first_planes = width // reduce_first first_planes = width // reduce_first
outplanes = planes * self.expansion outplanes = planes * self.expansion
first_dilation = first_dilation or dilation first_dilation = first_dilation or dilation
use_aa = aa_layer is not None
self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False) self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
self.bn1 = norm_layer(first_planes) self.bn1 = norm_layer(first_planes)
self.act1 = act_layer(inplace=True) self.act1 = act_layer(inplace=True)
self.conv2 = nn.Conv2d( self.conv2 = nn.Conv2d(
first_planes, width, kernel_size=3, stride=stride, first_planes, width, kernel_size=3, stride=1 if use_aa else stride,
padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False) padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
self.bn2 = norm_layer(width) self.bn2 = norm_layer(width)
self.act2 = act_layer(inplace=True) self.act2 = act_layer(inplace=True)
self.aa = aa_layer(channels=width) if stride == 2 and use_aa else None
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
self.bn3 = norm_layer(outplanes) self.bn3 = norm_layer(outplanes)
@ -240,6 +254,8 @@ class Bottleneck(nn.Module):
if self.drop_block is not None: if self.drop_block is not None:
x = self.drop_block(x) x = self.drop_block(x)
x = self.act2(x) x = self.act2(x)
if self.aa is not None:
x = self.aa(x)
x = self.conv3(x) x = self.conv3(x)
x = self.bn3(x) x = self.bn3(x)
@ -353,8 +369,9 @@ class ResNet(nn.Module):
Whether to use average pooling for projection skip connection between stages/downsample. Whether to use average pooling for projection skip connection between stages/downsample.
output_stride : int, default 32 output_stride : int, default 32
Set the output stride of the network, 32, 16, or 8. Typically used in segmentation. Set the output stride of the network, 32, 16, or 8. Typically used in segmentation.
act_layer : class, activation layer act_layer : nn.Module, activation layer
norm_layer : class, normalization layer norm_layer : nn.Module, normalization layer
aa_layer : nn.Module, anti-aliasing layer
drop_rate : float, default 0. drop_rate : float, default 0.
Dropout probability before classifier, for training Dropout probability before classifier, for training
global_pool : str, default 'avg' global_pool : str, default 'avg'
@ -363,7 +380,7 @@ class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000, in_chans=3, def __init__(self, block, layers, num_classes=1000, in_chans=3,
cardinality=1, base_width=64, stem_width=64, stem_type='', cardinality=1, base_width=64, stem_width=64, stem_type='',
block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32, block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, drop_path_rate=0., act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0.0, drop_path_rate=0.,
drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None): drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None):
block_args = block_args or dict() block_args = block_args or dict()
self.num_classes = num_classes self.num_classes = num_classes
@ -393,7 +410,14 @@ class ResNet(nn.Module):
self.conv1 = nn.Conv2d(in_chans, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) self.conv1 = nn.Conv2d(in_chans, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = norm_layer(self.inplanes) self.bn1 = norm_layer(self.inplanes)
self.act1 = act_layer(inplace=True) self.act1 = act_layer(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # Stem Pooling
if aa_layer is not None:
self.maxpool = nn.Sequential(*[
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
aa_layer(channels=self.inplanes, stride=2)
])
else:
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# Feature Blocks # Feature Blocks
dp = DropPath(drop_path_rate) if drop_path_rate else None dp = DropPath(drop_path_rate) if drop_path_rate else None
@ -410,7 +434,7 @@ class ResNet(nn.Module):
assert output_stride == 32 assert output_stride == 32
layer_args = list(zip(channels, layers, strides, dilations)) layer_args = list(zip(channels, layers, strides, dilations))
layer_kwargs = dict( layer_kwargs = dict(
reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer, reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer,
avg_down=avg_down, down_kernel_size=down_kernel_size, drop_path=dp, **block_args) avg_down=avg_down, down_kernel_size=down_kernel_size, drop_path=dp, **block_args)
self.layer1 = self._make_layer(block, *layer_args[0], **layer_kwargs) self.layer1 = self._make_layer(block, *layer_args[0], **layer_kwargs)
self.layer2 = self._make_layer(block, *layer_args[1], **layer_kwargs) self.layer2 = self._make_layer(block, *layer_args[1], **layer_kwargs)
@ -1114,3 +1138,29 @@ def ecaresnet101d_pruned(pretrained=False, num_classes=1000, in_chans=3, **kwarg
if pretrained: if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans) load_pretrained(model, default_cfg, num_classes, in_chans)
return model return model
@register_model
def resnetblur18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNet-18 model with blur anti-aliasing
"""
default_cfg = default_cfgs['resnetblur18']
model = ResNet(
BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, aa_layer=BlurPool2d, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def resnetblur50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNet-50 model with blur anti-aliasing
"""
default_cfg = default_cfgs['resnetblur50']
model = ResNet(
Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, aa_layer=BlurPool2d, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model

Loading…
Cancel
Save