Remove filter hack from BlurPool w/ non-persistent buffer. Use BlurPool2d instead of AntiAliasing.. for TResNet. Breaks PyTorch < 1.6.

pull/612/head
Ross Wightman 3 years ago
parent ddc743fdf8
commit 0d87650fea

@ -1,7 +1,6 @@
from .activations import *
from .adaptive_avgmax_pool import \
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
from .anti_aliasing import AntiAliasDownsampleLayer
from .blur_pool import BlurPool2d
from .classifier import ClassifierHead, create_classifier
from .cond_conv2d import CondConv2d, get_condconv_initializer

@ -1,60 +0,0 @@
import torch
import torch.nn.parallel
import torch.nn as nn
import torch.nn.functional as F
class AntiAliasDownsampleLayer(nn.Module):
def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2, no_jit: bool = False):
super(AntiAliasDownsampleLayer, self).__init__()
if no_jit:
self.op = Downsample(channels, filt_size, stride)
else:
self.op = DownsampleJIT(channels, filt_size, stride)
# FIXME I should probably override _apply and clear DownsampleJIT filter cache for .cuda(), .half(), etc calls
def forward(self, x):
return self.op(x)
@torch.jit.script
class DownsampleJIT(object):
def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2):
self.channels = channels
self.stride = stride
self.filt_size = filt_size
assert self.filt_size == 3
assert stride == 2
self.filt = {} # lazy init by device for DataParallel compat
def _create_filter(self, like: torch.Tensor):
filt = torch.tensor([1., 2., 1.], dtype=like.dtype, device=like.device)
filt = filt[:, None] * filt[None, :]
filt = filt / torch.sum(filt)
return filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
def __call__(self, input: torch.Tensor):
input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
filt = self.filt.get(str(input.device), self._create_filter(input))
return F.conv2d(input_pad, filt, stride=2, padding=0, groups=input.shape[1])
class Downsample(nn.Module):
def __init__(self, channels=None, filt_size=3, stride=2):
super(Downsample, self).__init__()
self.channels = channels
self.filt_size = filt_size
self.stride = stride
assert self.filt_size == 3
filt = torch.tensor([1., 2., 1.])
filt = filt[:, None] * filt[None, :]
filt = filt / torch.sum(filt)
# self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))
def forward(self, input):
input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
return F.conv2d(input_pad, self.filt, stride=self.stride, padding=0, groups=input.shape[1])

@ -3,8 +3,6 @@ BlurPool layer inspired by
- Kornia's Max_BlurPool2d
- Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`
FIXME merge this impl with those in `anti_aliasing.py`
Hacked together by Chris Ha and Ross Wightman
"""
@ -12,7 +10,6 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict
from .padding import get_padding
@ -29,30 +26,17 @@ class BlurPool2d(nn.Module):
Returns:
torch.Tensor: the transformed tensor.
"""
filt: Dict[str, torch.Tensor]
def __init__(self, channels, filt_size=3, stride=2) -> None:
super(BlurPool2d, self).__init__()
assert filt_size > 1
self.channels = channels
self.filt_size = filt_size
self.stride = stride
pad_size = [get_padding(filt_size, stride, dilation=1)] * 4
self.padding = nn.ReflectionPad2d(pad_size)
self._coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs) # for torchscript compat
self.filt = {} # lazy init by device for DataParallel compat
def _create_filter(self, like: torch.Tensor):
blur_filter = (self._coeffs[:, None] * self._coeffs[None, :]).to(dtype=like.dtype, device=like.device)
return blur_filter[None, None, :, :].repeat(self.channels, 1, 1, 1)
def _apply(self, fn):
# override nn.Module _apply, reset filter cache if used
self.filt = {}
super(BlurPool2d, self)._apply(fn)
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
C = input_tensor.shape[1]
blur_filt = self.filt.get(str(input_tensor.device), self._create_filter(input_tensor))
return F.conv2d(
self.padding(input_tensor), blur_filt, stride=self.stride, groups=C)
self.padding = [get_padding(filt_size, stride, dilation=1)] * 4
coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32))
blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :].repeat(self.channels, 1, 1, 1)
self.register_buffer('filt', blur_filter, persistent=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.pad(x, self.padding, 'reflect')
return F.conv2d(x, self.filt, stride=self.stride, groups=x.shape[1])

@ -5,16 +5,13 @@ https://arxiv.org/pdf/2003.13630.pdf
Original model: https://github.com/mrT23/TResNet
"""
import copy
from collections import OrderedDict
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from .helpers import build_model_with_cfg
from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, InplaceAbn, ClassifierHead, SEModule
from .layers import SpaceToDepthModule, BlurPool2d, InplaceAbn, ClassifierHead, SEModule
from .registry import register_model
__all__ = ['tresnet_m', 'tresnet_l', 'tresnet_xl']
@ -156,15 +153,12 @@ class Bottleneck(nn.Module):
class TResNet(nn.Module):
def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, no_aa_jit=False,
global_pool='fast', drop_rate=0.):
def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, global_pool='fast', drop_rate=0.):
self.num_classes = num_classes
self.drop_rate = drop_rate
super(TResNet, self).__init__()
# JIT layers
space_to_depth = SpaceToDepthModule()
aa_layer = partial(AntiAliasDownsampleLayer, no_jit=no_aa_jit)
aa_layer = BlurPool2d
# TResnet stages
self.inplanes = int(64 * width_factor)
@ -181,7 +175,7 @@ class TResNet(nn.Module):
# body
self.body = nn.Sequential(OrderedDict([
('SpaceToDepth', space_to_depth),
('SpaceToDepth', SpaceToDepthModule()),
('conv1', conv1),
('layer1', layer1),
('layer2', layer2),

Loading…
Cancel
Save