Remove filter hack from BlurPool w/ non-persistent buffer. Use BlurPool2d instead of AntiAliasDownsampleLayer for TResNet. Breaks PyTorch < 1.6.

pull/612/head
Ross Wightman 3 years ago
parent ddc743fdf8
commit 0d87650fea

@ -1,7 +1,6 @@
from .activations import * from .activations import *
from .adaptive_avgmax_pool import \ from .adaptive_avgmax_pool import \
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
from .anti_aliasing import AntiAliasDownsampleLayer
from .blur_pool import BlurPool2d from .blur_pool import BlurPool2d
from .classifier import ClassifierHead, create_classifier from .classifier import ClassifierHead, create_classifier
from .cond_conv2d import CondConv2d, get_condconv_initializer from .cond_conv2d import CondConv2d, get_condconv_initializer

@ -1,60 +0,0 @@
import torch
import torch.nn.parallel
import torch.nn as nn
import torch.nn.functional as F
class AntiAliasDownsampleLayer(nn.Module):
    """Anti-aliased downsampling module: blur filter + stride-2 conv.

    Wraps one of two implementations of the same operation: a plain
    nn.Module version (``Downsample``) or a TorchScript-compiled one
    (``DownsampleJIT``), selected by ``no_jit``.
    """

    def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2, no_jit: bool = False):
        super(AntiAliasDownsampleLayer, self).__init__()
        # Pick the eager module when JIT is disabled, the scripted class otherwise.
        impl = Downsample if no_jit else DownsampleJIT
        self.op = impl(channels, filt_size, stride)
        # FIXME consider overriding _apply to clear the DownsampleJIT per-device
        # filter cache on .cuda(), .half(), etc. calls

    def forward(self, x):
        return self.op(x)
@torch.jit.script
class DownsampleJIT(object):
    """TorchScript-compatible anti-aliased downsample: 3x3 binomial blur, stride 2.

    Keeps a per-device/dtype cache of the depthwise blur kernel so the same
    instance works under DataParallel (replicas see different devices).
    """

    def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2):
        self.channels = channels
        self.stride = stride
        self.filt_size = filt_size
        # Only the 3-tap / stride-2 configuration is implemented.
        assert self.filt_size == 3
        assert stride == 2
        self.filt = {}  # lazy init by device (and dtype) for DataParallel compat

    def _create_filter(self, like: torch.Tensor):
        # Normalized 3x3 binomial kernel ([1,2,1] outer product), one copy per
        # channel for a depthwise (grouped) convolution.
        filt = torch.tensor([1., 2., 1.], dtype=like.dtype, device=like.device)
        filt = filt[:, None] * filt[None, :]
        filt = filt / torch.sum(filt)
        return filt[None, None, :, :].repeat((self.channels, 1, 1, 1))

    def __call__(self, input: torch.Tensor):
        input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
        # BUG FIX: the original used `self.filt.get(key, self._create_filter(input))`,
        # which evaluates the default eagerly and never stores it -- the cache was
        # never populated and the kernel was rebuilt on every call. Populate the
        # cache on first use instead. The key includes dtype so a later .half()
        # input does not reuse a float32 kernel.
        key = str(input.device) + str(input.dtype)
        if key not in self.filt:
            self.filt[key] = self._create_filter(input)
        return F.conv2d(input_pad, self.filt[key], stride=2, padding=0, groups=input.shape[1])
class Downsample(nn.Module):
    """Eager anti-aliased downsample: fixed 3x3 binomial blur + stride-2 conv.

    The blur kernel is precomputed once and stored as a (non-trainable)
    buffer, one depthwise copy per channel.
    """

    def __init__(self, channels=None, filt_size=3, stride=2):
        super(Downsample, self).__init__()
        self.channels = channels
        self.filt_size = filt_size
        self.stride = stride
        # Only the 3-tap binomial filter is implemented.
        assert self.filt_size == 3
        # [1, 2, 1] taps -> normalized 3x3 blur kernel via outer product.
        taps = torch.tensor([1., 2., 1.])
        kernel = taps[:, None] * taps[None, :]
        kernel = kernel / kernel.sum()
        # One kernel copy per channel for a depthwise (grouped) convolution.
        self.register_buffer('filt', kernel[None, None, :, :].repeat((self.channels, 1, 1, 1)))

    def forward(self, input):
        padded = F.pad(input, (1, 1, 1, 1), 'reflect')
        return F.conv2d(padded, self.filt, stride=self.stride, padding=0, groups=input.shape[1])

@ -3,8 +3,6 @@ BlurPool layer inspired by
- Kornia's Max_BlurPool2d - Kornia's Max_BlurPool2d
- Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar` - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`
FIXME merge this impl with those in `anti_aliasing.py`
Hacked together by Chris Ha and Ross Wightman Hacked together by Chris Ha and Ross Wightman
""" """
@ -12,7 +10,6 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import numpy as np import numpy as np
from typing import Dict
from .padding import get_padding from .padding import get_padding
@ -29,30 +26,17 @@ class BlurPool2d(nn.Module):
Returns: Returns:
torch.Tensor: the transformed tensor. torch.Tensor: the transformed tensor.
""" """
filt: Dict[str, torch.Tensor]
def __init__(self, channels, filt_size=3, stride=2) -> None: def __init__(self, channels, filt_size=3, stride=2) -> None:
super(BlurPool2d, self).__init__() super(BlurPool2d, self).__init__()
assert filt_size > 1 assert filt_size > 1
self.channels = channels self.channels = channels
self.filt_size = filt_size self.filt_size = filt_size
self.stride = stride self.stride = stride
pad_size = [get_padding(filt_size, stride, dilation=1)] * 4 self.padding = [get_padding(filt_size, stride, dilation=1)] * 4
self.padding = nn.ReflectionPad2d(pad_size) coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32))
self._coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs) # for torchscript compat blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :].repeat(self.channels, 1, 1, 1)
self.filt = {} # lazy init by device for DataParallel compat self.register_buffer('filt', blur_filter, persistent=False)
def _create_filter(self, like: torch.Tensor): def forward(self, x: torch.Tensor) -> torch.Tensor:
blur_filter = (self._coeffs[:, None] * self._coeffs[None, :]).to(dtype=like.dtype, device=like.device) x = F.pad(x, self.padding, 'reflect')
return blur_filter[None, None, :, :].repeat(self.channels, 1, 1, 1) return F.conv2d(x, self.filt, stride=self.stride, groups=x.shape[1])
def _apply(self, fn):
# override nn.Module _apply, reset filter cache if used
self.filt = {}
super(BlurPool2d, self)._apply(fn)
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
C = input_tensor.shape[1]
blur_filt = self.filt.get(str(input_tensor.device), self._create_filter(input_tensor))
return F.conv2d(
self.padding(input_tensor), blur_filt, stride=self.stride, groups=C)

@ -5,16 +5,13 @@ https://arxiv.org/pdf/2003.13630.pdf
Original model: https://github.com/mrT23/TResNet Original model: https://github.com/mrT23/TResNet
""" """
import copy
from collections import OrderedDict from collections import OrderedDict
from functools import partial
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from .helpers import build_model_with_cfg from .helpers import build_model_with_cfg
from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, InplaceAbn, ClassifierHead, SEModule from .layers import SpaceToDepthModule, BlurPool2d, InplaceAbn, ClassifierHead, SEModule
from .registry import register_model from .registry import register_model
__all__ = ['tresnet_m', 'tresnet_l', 'tresnet_xl'] __all__ = ['tresnet_m', 'tresnet_l', 'tresnet_xl']
@ -156,15 +153,12 @@ class Bottleneck(nn.Module):
class TResNet(nn.Module): class TResNet(nn.Module):
def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, no_aa_jit=False, def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, global_pool='fast', drop_rate=0.):
global_pool='fast', drop_rate=0.):
self.num_classes = num_classes self.num_classes = num_classes
self.drop_rate = drop_rate self.drop_rate = drop_rate
super(TResNet, self).__init__() super(TResNet, self).__init__()
# JIT layers aa_layer = BlurPool2d
space_to_depth = SpaceToDepthModule()
aa_layer = partial(AntiAliasDownsampleLayer, no_jit=no_aa_jit)
# TResnet stages # TResnet stages
self.inplanes = int(64 * width_factor) self.inplanes = int(64 * width_factor)
@ -181,7 +175,7 @@ class TResNet(nn.Module):
# body # body
self.body = nn.Sequential(OrderedDict([ self.body = nn.Sequential(OrderedDict([
('SpaceToDepth', space_to_depth), ('SpaceToDepth', SpaceToDepthModule()),
('conv1', conv1), ('conv1', conv1),
('layer1', layer1), ('layer1', layer1),
('layer2', layer2), ('layer2', layer2),

Loading…
Cancel
Save