From 0d87650fea5ce607d07204806e8143a48917d96f Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 4 May 2021 16:56:28 -0700 Subject: [PATCH] Remove filter hack from BlurPool w/ non-persistent buffer. Use BlurPool2d instead of AntiAliasing.. for TResNet. Breaks PyTorch < 1.6. --- timm/models/layers/__init__.py | 1 - timm/models/layers/anti_aliasing.py | 60 ----------------------------- timm/models/layers/blur_pool.py | 32 ++++----------- timm/models/tresnet.py | 14 ++----- 4 files changed, 12 insertions(+), 95 deletions(-) delete mode 100644 timm/models/layers/anti_aliasing.py diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index ac0b6b41..eecbbde4 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -1,7 +1,6 @@ from .activations import * from .adaptive_avgmax_pool import \ adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d -from .anti_aliasing import AntiAliasDownsampleLayer from .blur_pool import BlurPool2d from .classifier import ClassifierHead, create_classifier from .cond_conv2d import CondConv2d, get_condconv_initializer diff --git a/timm/models/layers/anti_aliasing.py b/timm/models/layers/anti_aliasing.py deleted file mode 100644 index 9d3837e8..00000000 --- a/timm/models/layers/anti_aliasing.py +++ /dev/null @@ -1,60 +0,0 @@ -import torch -import torch.nn.parallel -import torch.nn as nn -import torch.nn.functional as F - - -class AntiAliasDownsampleLayer(nn.Module): - def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2, no_jit: bool = False): - super(AntiAliasDownsampleLayer, self).__init__() - if no_jit: - self.op = Downsample(channels, filt_size, stride) - else: - self.op = DownsampleJIT(channels, filt_size, stride) - - # FIXME I should probably override _apply and clear DownsampleJIT filter cache for .cuda(), .half(), etc calls - - def forward(self, x): - return self.op(x) - - -@torch.jit.script -class DownsampleJIT(object): - def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2): - self.channels = channels - self.stride = stride - self.filt_size = filt_size - assert self.filt_size == 3 - assert stride == 2 - self.filt = {} # lazy init by device for DataParallel compat - - def _create_filter(self, like: torch.Tensor): - filt = torch.tensor([1., 2., 1.], dtype=like.dtype, device=like.device) - filt = filt[:, None] * filt[None, :] - filt = filt / torch.sum(filt) - return filt[None, None, :, :].repeat((self.channels, 1, 1, 1)) - - def __call__(self, input: torch.Tensor): - input_pad = F.pad(input, (1, 1, 1, 1), 'reflect') - filt = self.filt.get(str(input.device), self._create_filter(input)) - return F.conv2d(input_pad, filt, stride=2, padding=0, groups=input.shape[1]) - - -class Downsample(nn.Module): - def __init__(self, channels=None, filt_size=3, stride=2): - super(Downsample, self).__init__() - self.channels = channels - self.filt_size = filt_size - self.stride = stride - - assert self.filt_size == 3 - filt = torch.tensor([1., 2., 1.]) - filt = filt[:, None] * filt[None, :] - filt = filt / torch.sum(filt) - - # self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1)) - self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1))) - - def forward(self, input): - input_pad = F.pad(input, (1, 1, 1, 1), 'reflect') - return F.conv2d(input_pad, self.filt, stride=self.stride, padding=0, groups=input.shape[1]) diff --git a/timm/models/layers/blur_pool.py b/timm/models/layers/blur_pool.py index 399cbe35..ca4ce756 100644 --- a/timm/models/layers/blur_pool.py +++ b/timm/models/layers/blur_pool.py @@ -3,8 +3,6 @@ BlurPool layer inspired by - Kornia's Max_BlurPool2d - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar` -FIXME merge this impl with those in `anti_aliasing.py` - Hacked together by Chris Ha and Ross Wightman """ @@ -12,7 +10,6 @@ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np -from typing import Dict from .padding import get_padding @@ -29,30 +26,17 @@ class BlurPool2d(nn.Module): Returns: torch.Tensor: the transformed tensor. """ - filt: Dict[str, torch.Tensor] - def __init__(self, channels, filt_size=3, stride=2) -> None: super(BlurPool2d, self).__init__() assert filt_size > 1 self.channels = channels self.filt_size = filt_size self.stride = stride - pad_size = [get_padding(filt_size, stride, dilation=1)] * 4 - self.padding = nn.ReflectionPad2d(pad_size) - self._coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs) # for torchscript compat - self.filt = {} # lazy init by device for DataParallel compat - - def _create_filter(self, like: torch.Tensor): - blur_filter = (self._coeffs[:, None] * self._coeffs[None, :]).to(dtype=like.dtype, device=like.device) - return blur_filter[None, None, :, :].repeat(self.channels, 1, 1, 1) - - def _apply(self, fn): - # override nn.Module _apply, reset filter cache if used - self.filt = {} - super(BlurPool2d, self)._apply(fn) - - def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: - C = input_tensor.shape[1] - blur_filt = self.filt.get(str(input_tensor.device), self._create_filter(input_tensor)) - return F.conv2d( - self.padding(input_tensor), blur_filt, stride=self.stride, groups=C) + self.padding = [get_padding(filt_size, stride, dilation=1)] * 4 + coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32)) + blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :].repeat(self.channels, 1, 1, 1) + self.register_buffer('filt', blur_filter, persistent=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.pad(x, self.padding, 'reflect') + return F.conv2d(x, self.filt, stride=self.stride, groups=x.shape[1]) diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py index ee1f3fc1..cec51cf4 100644 --- a/timm/models/tresnet.py +++ b/timm/models/tresnet.py @@ -5,16 +5,13 @@ https://arxiv.org/pdf/2003.13630.pdf Original model: https://github.com/mrT23/TResNet """ -import copy from collections import OrderedDict -from functools import partial import torch import torch.nn as nn -import torch.nn.functional as F from .helpers import build_model_with_cfg -from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, InplaceAbn, ClassifierHead, SEModule +from .layers import SpaceToDepthModule, BlurPool2d, InplaceAbn, ClassifierHead, SEModule from .registry import register_model __all__ = ['tresnet_m', 'tresnet_l', 'tresnet_xl'] @@ -156,15 +153,12 @@ class Bottleneck(nn.Module): class TResNet(nn.Module): - def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, no_aa_jit=False, - global_pool='fast', drop_rate=0.): + def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, global_pool='fast', drop_rate=0.): self.num_classes = num_classes self.drop_rate = drop_rate super(TResNet, self).__init__() - # JIT layers - space_to_depth = SpaceToDepthModule() - aa_layer = partial(AntiAliasDownsampleLayer, no_jit=no_aa_jit) + aa_layer = BlurPool2d # TResnet stages self.inplanes = int(64 * width_factor) @@ -181,7 +175,7 @@ class TResNet(nn.Module): # body self.body = nn.Sequential(OrderedDict([ - ('SpaceToDepth', space_to_depth), + ('SpaceToDepth', SpaceToDepthModule()), ('conv1', conv1), ('layer1', layer1), ('layer2', layer2),