From 3a287a6e764d54a984263f97923c60e8e101cc8b Mon Sep 17 00:00:00 2001 From: Chris Ha <15088501+VRandme@users.noreply.github.com> Date: Tue, 10 Mar 2020 23:27:24 +0900 Subject: [PATCH 1/7] Create blurpool.py Initial implementation of blur layer. currently tests as correct against Downsample of original github --- timm/models/layers/blurpool.py | 68 ++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 timm/models/layers/blurpool.py diff --git a/timm/models/layers/blurpool.py b/timm/models/layers/blurpool.py new file mode 100644 index 00000000..0ce4263e --- /dev/null +++ b/timm/models/layers/blurpool.py @@ -0,0 +1,68 @@ +'''independent attempt to implement + +MaxBlurPool2d in a more general fashion(separate maxpooling from BlurPool) +which was again inspired by +Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar` + +''' + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class BlurPool2d(nn.Module): + r"""Creates a module that computes blurs and downsample a given feature map. + See :cite:`zhang2019shiftinvar` for more details. + Corresponds to the Downsample class, which does blurring and subsampling + Args: + channels = Number of input channels + blur_filter_size (int): filter size for blurring. currently supports either 3 or 5 (most common) + defaults to 3. + stride (int): downsampling filter stride + Shape: + Returns: + torch.Tensor: the transformed tensor. + Examples: + """ + + def __init__(self, channels=None, blur_filter_size=3, stride=2) -> None: + super(BlurPool2d, self).__init__() + assert blur_filter_size in [3, 5] + self.channels = channels + self.blur_filter_size = blur_filter_size + self.stride = stride + + if blur_filter_size == 3: + pad_size = [1] * 4 + blur_matrix = torch.Tensor([[1., 2., 1]]) / 4 # binomial kernel b2 + else: + pad_size = [2] * 4 + blur_matrix = torch.Tensor([[1., 4., 6., 4., 1.]]) / 16 # binomial filter kernel b4 + + self.padding = nn.ReflectionPad2d(pad_size) + blur_filter = blur_matrix * blur_matrix.T + self.register_buffer('blur_filter', blur_filter[None, None, :, :].repeat((self.channels, 1, 1, 1))) + + def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # type: ignore + if not torch.is_tensor(input_tensor): + raise TypeError("Input input type is not a torch.Tensor. Got {}" + .format(type(input_tensor))) + if not len(input_tensor.shape) == 4: + raise ValueError("Invalid input shape, we expect BxCxHxW. Got: {}" + .format(input_tensor.shape)) + # apply blur_filter on input + return F.conv2d(self.padding(input_tensor), self.blur_filter, stride=self.stride, groups=input_tensor.shape[1]) + + +###################### +# functional interface +###################### + + +'''def blur_pool2d() -> torch.Tensor: + r"""Creates a module that computes pools and blurs and downsample a given + feature map. + See :class:`~kornia.contrib.MaxBlurPool2d` for details. + """ + return BlurPool2d(kernel_size, ceil_mode)(input)''' \ No newline at end of file From ce3d82b58b2eb3e526de5df608cfc7d181b8c916 Mon Sep 17 00:00:00 2001 From: Chris Ha <15088501+VRandme@users.noreply.github.com> Date: Wed, 11 Mar 2020 22:19:10 +0900 Subject: [PATCH 2/7] Update blurpool.py clean up code for PR --- timm/models/layers/blurpool.py | 32 +++++++++----------------------- 1 file changed, 9 insertions(+), 23 deletions(-) diff --git a/timm/models/layers/blurpool.py b/timm/models/layers/blurpool.py index 0ce4263e..96937114 100644 --- a/timm/models/layers/blurpool.py +++ b/timm/models/layers/blurpool.py @@ -1,7 +1,7 @@ -'''independent attempt to implement - -MaxBlurPool2d in a more general fashion(separate maxpooling from BlurPool) -which was again inspired by +''' +BlurPool layer inspired by +Kornia's Max_BlurPool2d +and Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar` ''' @@ -17,8 +17,7 @@ class BlurPool2d(nn.Module): Corresponds to the Downsample class, which does blurring and subsampling Args: channels = Number of input channels - blur_filter_size (int): filter size for blurring. currently supports either 3 or 5 (most common) - defaults to 3. + blur_filter_size (int): binomial filter size for blurring. currently supports 3(default) and 5. stride (int): downsampling filter stride Shape: Returns: @@ -35,16 +34,16 @@ class BlurPool2d(nn.Module): if blur_filter_size == 3: pad_size = [1] * 4 - blur_matrix = torch.Tensor([[1., 2., 1]]) / 4 # binomial kernel b2 + blur_matrix = torch.Tensor([[1., 2., 1]]) / 4 # binomial filter b2 else: pad_size = [2] * 4 - blur_matrix = torch.Tensor([[1., 4., 6., 4., 1.]]) / 16 # binomial filter kernel b4 + blur_matrix = torch.Tensor([[1., 4., 6., 4., 1.]]) / 16 # binomial filter b4 self.padding = nn.ReflectionPad2d(pad_size) blur_filter = blur_matrix * blur_matrix.T self.register_buffer('blur_filter', blur_filter[None, None, :, :].repeat((self.channels, 1, 1, 1))) - def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # type: ignore + def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # type: ignore if not torch.is_tensor(input_tensor): raise TypeError("Input input type is not a torch.Tensor. Got {}" .format(type(input_tensor))) @@ -52,17 +51,4 @@ class BlurPool2d(nn.Module): raise ValueError("Invalid input shape, we expect BxCxHxW. Got: {}" .format(input_tensor.shape)) # apply blur_filter on input - return F.conv2d(self.padding(input_tensor), self.blur_filter, stride=self.stride, groups=input_tensor.shape[1]) - - -###################### -# functional interface -###################### - - -'''def blur_pool2d() -> torch.Tensor: - r"""Creates a module that computes pools and blurs and downsample a given - feature map. - See :class:`~kornia.contrib.MaxBlurPool2d` for details. - """ - return BlurPool2d(kernel_size, ceil_mode)(input)''' \ No newline at end of file + return F.conv2d(self.padding(input_tensor), self.blur_filter, stride=self.stride, groups=input_tensor.shape[1]) \ No newline at end of file From acd1b6cccd9a333bd790b2d00d646a966e710151 Mon Sep 17 00:00:00 2001 From: Chris Ha <15088501+VRandme@users.noreply.github.com> Date: Sun, 22 Mar 2020 22:42:55 +0900 Subject: [PATCH 3/7] Implement Functional Blur on resnet.py 1. add ResNet argument blur='' 2. implement blur for maxpool and strided convs in downsampling blocks --- timm/models/layers/__init__.py | 1 + timm/models/layers/blurpool.py | 2 +- timm/models/resnet.py | 79 +++++++++++++++++++++++++++++----- 3 files changed, 70 insertions(+), 12 deletions(-) diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index f012c3cf..33450483 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -15,3 +15,4 @@ from .adaptive_avgmax_pool import \ from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path from .test_time_pool import TestTimePoolHead, apply_test_time_pool from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model +from .blurpool import BlurPool2d diff --git a/timm/models/layers/blurpool.py b/timm/models/layers/blurpool.py index 96937114..0b37a90c 100644 --- a/timm/models/layers/blurpool.py +++ b/timm/models/layers/blurpool.py @@ -17,7 +17,7 @@ class BlurPool2d(nn.Module): Corresponds to the Downsample class, which does blurring and subsampling Args: channels = Number of input channels - blur_filter_size (int): binomial filter size for blurring. currently supports 3(default) and 5. + blur_filter_size (int): binomial filter size for blurring. currently supports 3 (default) and 5. stride (int): downsampling filter stride Shape: Returns: diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 0013cbe0..057eca6c 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -12,7 +12,7 @@ import torch.nn.functional as F from .registry import register_model from .helpers import load_pretrained -from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn +from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn, BlurPool2d from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD @@ -104,6 +104,8 @@ default_cfgs = { interpolation='bicubic'), 'ecaresnet18': _cfg(), 'ecaresnet50': _cfg(), + 'resnetblur18': _cfg(), + 'resnetblur50': _cfg() } @@ -117,7 +119,7 @@ class BasicBlock(nn.Module): def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, - attn_layer=None, drop_block=None, drop_path=None): + attn_layer=None, drop_block=None, drop_path=None, blur=False): super(BasicBlock, self).__init__() assert cardinality == 1, 'BasicBlock only supports cardinality of 1' @@ -125,10 +127,19 @@ class BasicBlock(nn.Module): first_planes = planes // reduce_first outplanes = planes * self.expansion first_dilation = first_dilation or dilation + self.blur = blur - self.conv1 = nn.Conv2d( + if blur and stride==2: + self.conv1 = nn.Conv2d( + inplanes, first_planes, kernel_size=3, stride=1, padding=first_dilation, + dilation=first_dilation, bias=False) + self.blurpool=BlurPool2d(channels=first_planes) + else: + self.conv1 = nn.Conv2d( inplanes, first_planes, kernel_size=3, stride=stride, padding=first_dilation, dilation=first_dilation, bias=False) + self.blurpool = None + self.bn1 = norm_layer(first_planes) self.act1 = act_layer(inplace=True) self.conv2 = nn.Conv2d( @@ -154,7 +165,11 @@ class BasicBlock(nn.Module): x = self.bn1(x) if self.drop_block is not None: x = self.drop_block(x) - x = self.act1(x) + if self.blurpool is not None: + x = self.act1(x) + x = self.blurpool(x) + else: + x = self.act1(x) x = self.conv2(x) x = self.bn2(x) @@ -181,20 +196,30 @@ class Bottleneck(nn.Module): def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, - attn_layer=None, drop_block=None, drop_path=None): + attn_layer=None, drop_block=None, drop_path=None, blur=False): super(Bottleneck, self).__init__() width = int(math.floor(planes * (base_width / 64)) * cardinality) first_planes = width // reduce_first outplanes = planes * self.expansion first_dilation = first_dilation or dilation + self.blur = blur self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False) self.bn1 = norm_layer(first_planes) self.act1 = act_layer(inplace=True) - self.conv2 = nn.Conv2d( - first_planes, width, kernel_size=3, stride=stride, - padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False) + + if blur and stride==2: + self.conv2 = nn.Conv2d( + first_planes, width, kernel_size=3, stride=1, + padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False) + self.blurpool = BlurPool2d(channels=width) + else: + self.conv2 = nn.Conv2d( + first_planes, width, kernel_size=3, stride=stride, + padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False) + self.blurpool = None + self.bn2 = norm_layer(width) self.act2 = act_layer(inplace=True) self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) @@ -345,12 +370,19 @@ class ResNet(nn.Module): Dropout probability before classifier, for training global_pool : str, default 'avg' Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' + blur : str, default '' + Location of Blurring: + * '', default - Not applied + * 'max' - only stem layer MaxPool will be blurred + * 'strided' - only strided convolutions in the downsampling blocks (assembled-cnn style) + * 'max_strided' - on both stem MaxPool and strided convolutions (zhang2019shiftinvar style for ResNets) + """ def __init__(self, block, layers, num_classes=1000, in_chans=3, cardinality=1, base_width=64, stem_width=64, stem_type='', block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, drop_path_rate=0., - drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None): + drop_block_rate=0., global_pool='avg', blur='', zero_init_last_bn=True, block_args=None): block_args = block_args or dict() self.num_classes = num_classes deep_stem = 'deep' in stem_type @@ -359,6 +391,7 @@ class ResNet(nn.Module): self.base_width = base_width self.drop_rate = drop_rate self.expansion = block.expansion + self.blur = 'strided' in blur super(ResNet, self).__init__() # Stem @@ -379,7 +412,13 @@ class ResNet(nn.Module): self.conv1 = nn.Conv2d(in_chans, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = norm_layer(self.inplanes) self.act1 = act_layer(inplace=True) - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + # Stem Blur + if 'max' in blur : + self.maxpool = nn.Sequential(*[ + nn.MaxPool2d(kernel_size=3, stride=1, padding=1), + BlurPool2d(channels=self.inplanes)]) + else : + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # Feature Blocks dp = DropPath(drop_path_rate) if drop_path_rate else None @@ -432,7 +471,7 @@ class ResNet(nn.Module): block_kwargs = dict( cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first, dilation=dilation, **kwargs) - layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **block_kwargs)] + layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, blur=self.blur, **block_kwargs)] self.inplanes = planes * block.expansion layers += [block(self.inplanes, planes, **block_kwargs) for _ in range(1, blocks)] @@ -1022,3 +1061,21 @@ def ecaresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) return model + +@register_model +def resnetblur18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-18 model. With original style blur + """ + default_cfg = default_cfgs['resnetblur18'] + model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, blur='max_strided',**kwargs) + model.default_cfg = default_cfg + return model + +@register_model +def resnetblur50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-50 model. With assembled-cnn style blur + """ + default_cfg = default_cfgs['resnetblur18'] + model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, blur='strided', **kwargs) + model.default_cfg = default_cfg + return model \ No newline at end of file From 6cdeca24a389cce7693e043b89db04e230b3a405 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 1 Apr 2020 12:01:59 -0700 Subject: [PATCH 4/7] Some cleanup and fixes for initial BlurPool impl. Still some testing and tweaks to go... --- timm/models/layers/blurpool.py | 43 +++++++++++++++------------- timm/models/resnet.py | 52 +++++++++++++--------------------- 2 files changed, 43 insertions(+), 52 deletions(-) diff --git a/timm/models/layers/blurpool.py b/timm/models/layers/blurpool.py index 0b37a90c..a12274d8 100644 --- a/timm/models/layers/blurpool.py +++ b/timm/models/layers/blurpool.py @@ -1,14 +1,17 @@ -''' +""" BlurPool layer inspired by -Kornia's Max_BlurPool2d -and -Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar` + - Kornia's Max_BlurPool2d + - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar` + +Hacked together by Chris Ha and Ross Wightman +""" -''' import torch import torch.nn as nn import torch.nn.functional as F +import numpy as np +from .padding import get_padding class BlurPool2d(nn.Module): @@ -25,30 +28,30 @@ class BlurPool2d(nn.Module): Examples: """ - def __init__(self, channels=None, blur_filter_size=3, stride=2) -> None: + def __init__(self, channels, blur_filter_size=3, stride=2) -> None: super(BlurPool2d, self).__init__() - assert blur_filter_size in [3, 5] + assert blur_filter_size > 1 self.channels = channels self.blur_filter_size = blur_filter_size self.stride = stride - if blur_filter_size == 3: - pad_size = [1] * 4 - blur_matrix = torch.Tensor([[1., 2., 1]]) / 4 # binomial filter b2 - else: - pad_size = [2] * 4 - blur_matrix = torch.Tensor([[1., 4., 6., 4., 1.]]) / 16 # binomial filter b4 - + pad_size = [get_padding(blur_filter_size, stride, dilation=1)] * 4 self.padding = nn.ReflectionPad2d(pad_size) - blur_filter = blur_matrix * blur_matrix.T + + blur_matrix = (np.poly1d((0.5, 0.5)) ** (blur_filter_size - 1)).coeffs + blur_filter = torch.Tensor(blur_matrix[:, None] * blur_matrix[None, :]) + # FIXME figure a clean hack to prevent the filter from getting saved in weights, but still + # plays nice with recursive module apply for fn like .cuda(), .type(), etc -RW self.register_buffer('blur_filter', blur_filter[None, None, :, :].repeat((self.channels, 1, 1, 1))) def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # type: ignore if not torch.is_tensor(input_tensor): - raise TypeError("Input input type is not a torch.Tensor. Got {}" - .format(type(input_tensor))) + raise TypeError("Input input type is not a torch.Tensor. Got {}".format(type(input_tensor))) if not len(input_tensor.shape) == 4: - raise ValueError("Invalid input shape, we expect BxCxHxW. Got: {}" - .format(input_tensor.shape)) + raise ValueError("Invalid input shape, we expect BxCxHxW. Got: {}".format(input_tensor.shape)) # apply blur_filter on input - return F.conv2d(self.padding(input_tensor), self.blur_filter, stride=self.stride, groups=input_tensor.shape[1]) \ No newline at end of file + return F.conv2d( + self.padding(input_tensor), + self.blur_filter.type(input_tensor.dtype), + stride=self.stride, + groups=input_tensor.shape[1]) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 057eca6c..fdb8097c 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -127,21 +127,14 @@ class BasicBlock(nn.Module): first_planes = planes // reduce_first outplanes = planes * self.expansion first_dilation = first_dilation or dilation - self.blur = blur - if blur and stride==2: - self.conv1 = nn.Conv2d( - inplanes, first_planes, kernel_size=3, stride=1, padding=first_dilation, - dilation=first_dilation, bias=False) - self.blurpool=BlurPool2d(channels=first_planes) - else: - self.conv1 = nn.Conv2d( - inplanes, first_planes, kernel_size=3, stride=stride, padding=first_dilation, + self.conv1 = nn.Conv2d( + inplanes, first_planes, kernel_size=3, stride=1 if blur else stride, padding=first_dilation, dilation=first_dilation, bias=False) - self.blurpool = None - self.bn1 = norm_layer(first_planes) self.act1 = act_layer(inplace=True) + self.blurpool = BlurPool2d(channels=first_planes) if stride == 2 and blur else None + self.conv2 = nn.Conv2d( first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False) self.bn2 = norm_layer(outplanes) @@ -165,11 +158,9 @@ class BasicBlock(nn.Module): x = self.bn1(x) if self.drop_block is not None: x = self.drop_block(x) + x = self.act1(x) if self.blurpool is not None: - x = self.act1(x) x = self.blurpool(x) - else: - x = self.act1(x) x = self.conv2(x) x = self.bn2(x) @@ -209,19 +200,13 @@ class Bottleneck(nn.Module): self.bn1 = norm_layer(first_planes) self.act1 = act_layer(inplace=True) - if blur and stride==2: - self.conv2 = nn.Conv2d( - first_planes, width, kernel_size=3, stride=1, - padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False) - self.blurpool = BlurPool2d(channels=width) - else: - self.conv2 = nn.Conv2d( - first_planes, width, kernel_size=3, stride=stride, - padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False) - self.blurpool = None - + self.conv2 = nn.Conv2d( + first_planes, width, kernel_size=3, stride=1 if blur else stride, + padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False) self.bn2 = norm_layer(width) self.act2 = act_layer(inplace=True) + self.blurpool = BlurPool2d(channels=width) if stride == 2 and blur else None + self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) self.bn3 = norm_layer(outplanes) @@ -251,6 +236,8 @@ class Bottleneck(nn.Module): if self.drop_block is not None: x = self.drop_block(x) x = self.act2(x) + if self.blurpool is not None: + x = self.blurpool(x) x = self.conv3(x) x = self.bn3(x) @@ -412,11 +399,12 @@ class ResNet(nn.Module): self.conv1 = nn.Conv2d(in_chans, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = norm_layer(self.inplanes) self.act1 = act_layer(inplace=True) - # Stem Blur + # Stem Pooling if 'max' in blur : self.maxpool = nn.Sequential(*[ - nn.MaxPool2d(kernel_size=3, stride=1, padding=1), - BlurPool2d(channels=self.inplanes)]) + nn.MaxPool2d(kernel_size=3, stride=1, padding=1), + BlurPool2d(channels=self.inplanes, stride=2) + ]) else : self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) @@ -470,8 +458,8 @@ class ResNet(nn.Module): block_kwargs = dict( cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first, - dilation=dilation, **kwargs) - layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, blur=self.blur, **block_kwargs)] + dilation=dilation, blur=self.blur, **kwargs) + layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **block_kwargs)] self.inplanes = planes * block.expansion layers += [block(self.inplanes, planes, **block_kwargs) for _ in range(1, blocks)] @@ -1075,7 +1063,7 @@ def resnetblur18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): def resnetblur50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-50 model. With assembled-cnn style blur """ - default_cfg = default_cfgs['resnetblur18'] - model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, blur='strided', **kwargs) + default_cfg = default_cfgs['resnetblur50'] + model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, blur='max_strided', **kwargs) model.default_cfg = default_cfg return model \ No newline at end of file From f17b42bc33eca08308192f4079f8f609dada81db Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 1 Apr 2020 13:15:06 -0700 Subject: [PATCH 5/7] Blur filter no longer a buffer --- timm/models/layers/blurpool.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/timm/models/layers/blurpool.py b/timm/models/layers/blurpool.py index a12274d8..6df2b748 100644 --- a/timm/models/layers/blurpool.py +++ b/timm/models/layers/blurpool.py @@ -40,18 +40,16 @@ class BlurPool2d(nn.Module): blur_matrix = (np.poly1d((0.5, 0.5)) ** (blur_filter_size - 1)).coeffs blur_filter = torch.Tensor(blur_matrix[:, None] * blur_matrix[None, :]) - # FIXME figure a clean hack to prevent the filter from getting saved in weights, but still - # plays nice with recursive module apply for fn like .cuda(), .type(), etc -RW - self.register_buffer('blur_filter', blur_filter[None, None, :, :].repeat((self.channels, 1, 1, 1))) + self.blur_filter = blur_filter[None, None, :, :] + + def _apply(self, fn): + # override nn.Module _apply to prevent need for blur_filter to be registered as a buffer, + # this keeps it out of state dict, but allows .cuda(), .type(), etc to work as expected + super(BlurPool2d, self)._apply(fn) + self.blur_filter = fn(self.blur_filter) def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # type: ignore - if not torch.is_tensor(input_tensor): - raise TypeError("Input input type is not a torch.Tensor. Got {}".format(type(input_tensor))) - if not len(input_tensor.shape) == 4: - raise ValueError("Invalid input shape, we expect BxCxHxW. Got: {}".format(input_tensor.shape)) - # apply blur_filter on input return F.conv2d( self.padding(input_tensor), - self.blur_filter.type(input_tensor.dtype), - stride=self.stride, - groups=input_tensor.shape[1]) + self.blur_filter.type(input_tensor.dtype).expand(C, -1, -1, -1), + stride=self.stride, groups=input_tensor.shape[1]) From 1a9ab07307a5dccaec08566bc599b3c36c89dae8 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 1 Apr 2020 13:19:08 -0700 Subject: [PATCH 6/7] One too many changes at a time, fix missing C --- timm/models/layers/blurpool.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/layers/blurpool.py b/timm/models/layers/blurpool.py index 6df2b748..57af3e0e 100644 --- a/timm/models/layers/blurpool.py +++ b/timm/models/layers/blurpool.py @@ -49,7 +49,7 @@ class BlurPool2d(nn.Module): self.blur_filter = fn(self.blur_filter) def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # type: ignore + C = input_tensor.shape[1] return F.conv2d( self.padding(input_tensor), - self.blur_filter.type(input_tensor.dtype).expand(C, -1, -1, -1), - stride=self.stride, groups=input_tensor.shape[1]) + self.blur_filter.type(input_tensor.dtype).expand(C, -1, -1, -1), stride=self.stride, groups=C) From 2681a8d618ba85085369bcf22a89c6fb4c5076be Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 1 May 2020 17:00:21 -0700 Subject: [PATCH 7/7] Final blurpool2d cleanup and add resnetblur50 weights, match tresnet Downsample arg order to BlurPool2d for interop --- timm/models/helpers.py | 4 +- timm/models/layers/__init__.py | 2 +- timm/models/layers/anti_aliasing.py | 17 ++++---- timm/models/layers/blur_pool.py | 58 ++++++++++++++++++++++++++ timm/models/layers/blurpool.py | 55 ------------------------- timm/models/resnet.py | 64 ++++++++++++++--------------- 6 files changed, 101 insertions(+), 99 deletions(-) create mode 100644 timm/models/layers/blur_pool.py delete mode 100644 timm/models/layers/blurpool.py diff --git a/timm/models/helpers.py b/timm/models/helpers.py index 3183f631..3baad3bf 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -31,9 +31,9 @@ def load_state_dict(checkpoint_path, use_ema=False): raise FileNotFoundError() -def load_checkpoint(model, checkpoint_path, use_ema=False): +def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True): state_dict = load_state_dict(checkpoint_path, use_ema) - model.load_state_dict(state_dict) + model.load_state_dict(state_dict, strict=strict) def resume_checkpoint(model, checkpoint_path): diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 0e9a957f..4f84bb9e 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -18,4 +18,4 @@ from .test_time_pool import TestTimePoolHead, apply_test_time_pool from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model from .anti_aliasing import AntiAliasDownsampleLayer from .space_to_depth import SpaceToDepthModule -from .blurpool import BlurPool2d +from .blur_pool import BlurPool2d diff --git a/timm/models/layers/anti_aliasing.py b/timm/models/layers/anti_aliasing.py index 38f96ee3..9d3837e8 100644 --- a/timm/models/layers/anti_aliasing.py +++ b/timm/models/layers/anti_aliasing.py @@ -5,12 +5,12 @@ import torch.nn.functional as F class AntiAliasDownsampleLayer(nn.Module): - def __init__(self, no_jit: bool = False, filt_size: int = 3, stride: int = 2, channels: int = 0): + def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2, no_jit: bool = False): super(AntiAliasDownsampleLayer, self).__init__() if no_jit: - self.op = Downsample(filt_size, stride, channels) + self.op = Downsample(channels, filt_size, stride) else: - self.op = DownsampleJIT(filt_size, stride, channels) + self.op = DownsampleJIT(channels, filt_size, stride) # FIXME I should probably override _apply and clear DownsampleJIT filter cache for .cuda(), .half(), etc calls @@ -20,10 +20,10 @@ class AntiAliasDownsampleLayer(nn.Module): @torch.jit.script class DownsampleJIT(object): - def __init__(self, filt_size: int = 3, stride: int = 2, channels: int = 0): + def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2): + self.channels = channels self.stride = stride self.filt_size = filt_size - self.channels = channels assert self.filt_size == 3 assert stride == 2 self.filt = {} # lazy init by device for DataParallel compat @@ -32,8 +32,7 @@ class DownsampleJIT(object): filt = torch.tensor([1., 2., 1.], dtype=like.dtype, device=like.device) filt = filt[:, None] * filt[None, :] filt = filt / torch.sum(filt) - filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1)) - return filt + return filt[None, None, :, :].repeat((self.channels, 1, 1, 1)) def __call__(self, input: torch.Tensor): input_pad = F.pad(input, (1, 1, 1, 1), 'reflect') @@ -42,11 +41,11 @@ class DownsampleJIT(object): class Downsample(nn.Module): - def __init__(self, filt_size=3, stride=2, channels=None): + def __init__(self, channels=None, filt_size=3, stride=2): super(Downsample, self).__init__() + self.channels = channels self.filt_size = filt_size self.stride = stride - self.channels = channels assert self.filt_size == 3 filt = torch.tensor([1., 2., 1.]) diff --git a/timm/models/layers/blur_pool.py b/timm/models/layers/blur_pool.py new file mode 100644 index 00000000..399cbe35 --- /dev/null +++ b/timm/models/layers/blur_pool.py @@ -0,0 +1,58 @@ +""" +BlurPool layer inspired by + - Kornia's Max_BlurPool2d + - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar` + +FIXME merge this impl with those in `anti_aliasing.py` + +Hacked together by Chris Ha and Ross Wightman +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from typing import Dict +from .padding import get_padding + + +class BlurPool2d(nn.Module): + r"""Creates a module that computes blurs and downsample a given feature map. + See :cite:`zhang2019shiftinvar` for more details. + Corresponds to the Downsample class, which does blurring and subsampling + + Args: + channels = Number of input channels + filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5. + stride (int): downsampling filter stride + + Returns: + torch.Tensor: the transformed tensor. + """ + filt: Dict[str, torch.Tensor] + + def __init__(self, channels, filt_size=3, stride=2) -> None: + super(BlurPool2d, self).__init__() + assert filt_size > 1 + self.channels = channels + self.filt_size = filt_size + self.stride = stride + pad_size = [get_padding(filt_size, stride, dilation=1)] * 4 + self.padding = nn.ReflectionPad2d(pad_size) + self._coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs) # for torchscript compat + self.filt = {} # lazy init by device for DataParallel compat + + def _create_filter(self, like: torch.Tensor): + blur_filter = (self._coeffs[:, None] * self._coeffs[None, :]).to(dtype=like.dtype, device=like.device) + return blur_filter[None, None, :, :].repeat(self.channels, 1, 1, 1) + + def _apply(self, fn): + # override nn.Module _apply, reset filter cache if used + self.filt = {} + super(BlurPool2d, self)._apply(fn) + + def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: + C = input_tensor.shape[1] + blur_filt = self.filt.get(str(input_tensor.device), self._create_filter(input_tensor)) + return F.conv2d( + self.padding(input_tensor), blur_filt, stride=self.stride, groups=C) diff --git a/timm/models/layers/blurpool.py b/timm/models/layers/blurpool.py deleted file mode 100644 index 57af3e0e..00000000 --- a/timm/models/layers/blurpool.py +++ /dev/null @@ -1,55 +0,0 @@ -""" -BlurPool layer inspired by - - Kornia's Max_BlurPool2d - - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar` - -Hacked together by Chris Ha and Ross Wightman -""" - - -import torch -import torch.nn as nn -import torch.nn.functional as F -import numpy as np -from .padding import get_padding - - -class BlurPool2d(nn.Module): - r"""Creates a module that computes blurs and downsample a given feature map. - See :cite:`zhang2019shiftinvar` for more details. - Corresponds to the Downsample class, which does blurring and subsampling - Args: - channels = Number of input channels - blur_filter_size (int): binomial filter size for blurring. currently supports 3 (default) and 5. - stride (int): downsampling filter stride - Shape: - Returns: - torch.Tensor: the transformed tensor. - Examples: - """ - - def __init__(self, channels, blur_filter_size=3, stride=2) -> None: - super(BlurPool2d, self).__init__() - assert blur_filter_size > 1 - self.channels = channels - self.blur_filter_size = blur_filter_size - self.stride = stride - - pad_size = [get_padding(blur_filter_size, stride, dilation=1)] * 4 - self.padding = nn.ReflectionPad2d(pad_size) - - blur_matrix = (np.poly1d((0.5, 0.5)) ** (blur_filter_size - 1)).coeffs - blur_filter = torch.Tensor(blur_matrix[:, None] * blur_matrix[None, :]) - self.blur_filter = blur_filter[None, None, :, :] - - def _apply(self, fn): - # override nn.Module _apply to prevent need for blur_filter to be registered as a buffer, - # this keeps it out of state dict, but allows .cuda(), .type(), etc to work as expected - super(BlurPool2d, self)._apply(fn) - self.blur_filter = fn(self.blur_filter) - - def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # type: ignore - C = input_tensor.shape[1] - return F.conv2d( - self.padding(input_tensor), - self.blur_filter.type(input_tensor.dtype).expand(C, -1, -1, -1), stride=self.stride, groups=C) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 55c38dff..4e865705 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -118,8 +118,11 @@ default_cfgs = { 'ecaresnet101d_pruned': _cfg( url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth', interpolation='bicubic'), - 'resnetblur18': _cfg(), - 'resnetblur50': _cfg() + 'resnetblur18': _cfg( + interpolation='bicubic'), + 'resnetblur50': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnetblur50-84f4748f.pth', + interpolation='bicubic') } @@ -133,7 +136,7 @@ class BasicBlock(nn.Module): def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, - attn_layer=None, drop_block=None, drop_path=None, blur=False): + attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): super(BasicBlock, self).__init__() assert cardinality == 1, 'BasicBlock only supports cardinality of 1' @@ -141,13 +144,14 @@ class BasicBlock(nn.Module): first_planes = planes // reduce_first outplanes = planes * self.expansion first_dilation = first_dilation or dilation + use_aa = aa_layer is not None self.conv1 = nn.Conv2d( - inplanes, first_planes, kernel_size=3, stride=1 if blur else stride, padding=first_dilation, + inplanes, first_planes, kernel_size=3, stride=1 if use_aa else stride, padding=first_dilation, dilation=first_dilation, bias=False) self.bn1 = norm_layer(first_planes) self.act1 = act_layer(inplace=True) - self.blurpool = BlurPool2d(channels=first_planes) if stride == 2 and blur else None + self.aa = aa_layer(channels=first_planes) if stride == 2 and use_aa else None self.conv2 = nn.Conv2d( first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False) @@ -173,8 +177,8 @@ class BasicBlock(nn.Module): if self.drop_block is not None: x = self.drop_block(x) x = self.act1(x) - if self.blurpool is not None: - x = self.blurpool(x) + if self.aa is not None: + x = self.aa(x) x = self.conv2(x) x = self.bn2(x) @@ -201,25 +205,25 @@ class Bottleneck(nn.Module): def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, - attn_layer=None, drop_block=None, drop_path=None, blur=False): + attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): super(Bottleneck, self).__init__() width = int(math.floor(planes * (base_width / 64)) * cardinality) first_planes = width // reduce_first outplanes = planes * self.expansion first_dilation = first_dilation or dilation - self.blur = blur + use_aa = aa_layer is not None self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False) self.bn1 = norm_layer(first_planes) self.act1 = act_layer(inplace=True) self.conv2 = nn.Conv2d( - first_planes, width, kernel_size=3, stride=1 if blur else stride, + first_planes, width, kernel_size=3, stride=1 if use_aa else stride, padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False) self.bn2 = norm_layer(width) self.act2 = act_layer(inplace=True) - self.blurpool = BlurPool2d(channels=width) if stride == 2 and blur else None + self.aa = aa_layer(channels=width) if stride == 2 and use_aa else None self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) self.bn3 = norm_layer(outplanes) @@ -250,8 +254,8 @@ class Bottleneck(nn.Module): if self.drop_block is not None: x = self.drop_block(x) x = self.act2(x) - if self.blurpool is not None: - x = self.blurpool(x) + if self.aa is not None: + x = self.aa(x) x = self.conv3(x) x = self.bn3(x) @@ -365,25 +369,19 @@ class ResNet(nn.Module): Whether to use average pooling for projection skip connection between stages/downsample. output_stride : int, default 32 Set the output stride of the network, 32, 16, or 8. Typically used in segmentation. - act_layer : class, activation layer - norm_layer : class, normalization layer + act_layer : nn.Module, activation layer + norm_layer : nn.Module, normalization layer + aa_layer : nn.Module, anti-aliasing layer drop_rate : float, default 0. Dropout probability before classifier, for training global_pool : str, default 'avg' Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' - blur : str, default '' - Location of Blurring: - * '', default - Not applied - * 'max' - only stem layer MaxPool will be blurred - * 'strided' - only strided convolutions in the downsampling blocks (assembled-cnn style) - * 'max_strided' - on both stem MaxPool and strided convolutions (zhang2019shiftinvar style for ResNets) - """ def __init__(self, block, layers, num_classes=1000, in_chans=3, cardinality=1, base_width=64, stem_width=64, stem_type='', block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32, - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, drop_path_rate=0., - drop_block_rate=0., global_pool='avg', blur='', zero_init_last_bn=True, block_args=None): + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0.0, drop_path_rate=0., + drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None): block_args = block_args or dict() self.num_classes = num_classes deep_stem = 'deep' in stem_type @@ -392,7 +390,6 @@ class ResNet(nn.Module): self.base_width = base_width self.drop_rate = drop_rate self.expansion = block.expansion - self.blur = 'strided' in blur super(ResNet, self).__init__() # Stem @@ -414,12 +411,12 @@ class ResNet(nn.Module): self.bn1 = norm_layer(self.inplanes) self.act1 = act_layer(inplace=True) # Stem Pooling - if 'max' in blur : + if aa_layer is not None: self.maxpool = nn.Sequential(*[ nn.MaxPool2d(kernel_size=3, stride=1, padding=1), - BlurPool2d(channels=self.inplanes, stride=2) + aa_layer(channels=self.inplanes, stride=2) ]) - else : + else: self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # Feature Blocks @@ -437,7 +434,7 @@ class ResNet(nn.Module): assert output_stride == 32 layer_args = list(zip(channels, layers, strides, dilations)) layer_kwargs = dict( - reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer, + reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer, avg_down=avg_down, down_kernel_size=down_kernel_size, drop_path=dp, **block_args) self.layer1 = self._make_layer(block, *layer_args[0], **layer_kwargs) self.layer2 = self._make_layer(block, *layer_args[1], **layer_kwargs) @@ -472,7 +469,7 @@ class ResNet(nn.Module): block_kwargs = dict( cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first, - dilation=dilation, blur=self.blur, **kwargs) + dilation=dilation, **kwargs) layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **block_kwargs)] self.inplanes = planes * block.expansion layers += [block(self.inplanes, planes, **block_kwargs) for _ in range(1, blocks)] @@ -1148,18 +1145,21 @@ def resnetblur18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-18 model with blur anti-aliasing """ default_cfg = default_cfgs['resnetblur18'] - model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, blur='max_strided',**kwargs) + model = ResNet( + BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, aa_layer=BlurPool2d, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) return model + @register_model def resnetblur50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-50 model with blur anti-aliasing """ default_cfg = default_cfgs['resnetblur50'] - model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, blur='max_strided', **kwargs) + model = ResNet( + Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, aa_layer=BlurPool2d, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans)