Final blurpool2d cleanup and add resnetblur50 weights, match tresnet Downsample arg order to BlurPool2d for interop
parent
9590f301a9
commit
2681a8d618
@ -0,0 +1,58 @@
|
||||
"""
|
||||
BlurPool layer inspired by
|
||||
- Kornia's Max_BlurPool2d
|
||||
- Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`
|
||||
|
||||
FIXME merge this impl with those in `anti_aliasing.py`
|
||||
|
||||
Hacked together by Chris Ha and Ross Wightman
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from typing import Dict
|
||||
from .padding import get_padding
|
||||
|
||||
|
||||
class BlurPool2d(nn.Module):
|
||||
r"""Creates a module that computes blurs and downsample a given feature map.
|
||||
See :cite:`zhang2019shiftinvar` for more details.
|
||||
Corresponds to the Downsample class, which does blurring and subsampling
|
||||
|
||||
Args:
|
||||
channels = Number of input channels
|
||||
filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5.
|
||||
stride (int): downsampling filter stride
|
||||
|
||||
Returns:
|
||||
torch.Tensor: the transformed tensor.
|
||||
"""
|
||||
filt: Dict[str, torch.Tensor]
|
||||
|
||||
def __init__(self, channels, filt_size=3, stride=2) -> None:
|
||||
super(BlurPool2d, self).__init__()
|
||||
assert filt_size > 1
|
||||
self.channels = channels
|
||||
self.filt_size = filt_size
|
||||
self.stride = stride
|
||||
pad_size = [get_padding(filt_size, stride, dilation=1)] * 4
|
||||
self.padding = nn.ReflectionPad2d(pad_size)
|
||||
self._coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs) # for torchscript compat
|
||||
self.filt = {} # lazy init by device for DataParallel compat
|
||||
|
||||
def _create_filter(self, like: torch.Tensor):
|
||||
blur_filter = (self._coeffs[:, None] * self._coeffs[None, :]).to(dtype=like.dtype, device=like.device)
|
||||
return blur_filter[None, None, :, :].repeat(self.channels, 1, 1, 1)
|
||||
|
||||
def _apply(self, fn):
|
||||
# override nn.Module _apply, reset filter cache if used
|
||||
self.filt = {}
|
||||
super(BlurPool2d, self)._apply(fn)
|
||||
|
||||
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||
C = input_tensor.shape[1]
|
||||
blur_filt = self.filt.get(str(input_tensor.device), self._create_filter(input_tensor))
|
||||
return F.conv2d(
|
||||
self.padding(input_tensor), blur_filt, stride=self.stride, groups=C)
|
@ -1,55 +0,0 @@
|
||||
"""
|
||||
BlurPool layer inspired by
|
||||
- Kornia's Max_BlurPool2d
|
||||
- Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`
|
||||
|
||||
Hacked together by Chris Ha and Ross Wightman
|
||||
"""
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from .padding import get_padding
|
||||
|
||||
|
||||
class BlurPool2d(nn.Module):
|
||||
r"""Creates a module that computes blurs and downsample a given feature map.
|
||||
See :cite:`zhang2019shiftinvar` for more details.
|
||||
Corresponds to the Downsample class, which does blurring and subsampling
|
||||
Args:
|
||||
channels = Number of input channels
|
||||
blur_filter_size (int): binomial filter size for blurring. currently supports 3 (default) and 5.
|
||||
stride (int): downsampling filter stride
|
||||
Shape:
|
||||
Returns:
|
||||
torch.Tensor: the transformed tensor.
|
||||
Examples:
|
||||
"""
|
||||
|
||||
def __init__(self, channels, blur_filter_size=3, stride=2) -> None:
|
||||
super(BlurPool2d, self).__init__()
|
||||
assert blur_filter_size > 1
|
||||
self.channels = channels
|
||||
self.blur_filter_size = blur_filter_size
|
||||
self.stride = stride
|
||||
|
||||
pad_size = [get_padding(blur_filter_size, stride, dilation=1)] * 4
|
||||
self.padding = nn.ReflectionPad2d(pad_size)
|
||||
|
||||
blur_matrix = (np.poly1d((0.5, 0.5)) ** (blur_filter_size - 1)).coeffs
|
||||
blur_filter = torch.Tensor(blur_matrix[:, None] * blur_matrix[None, :])
|
||||
self.blur_filter = blur_filter[None, None, :, :]
|
||||
|
||||
def _apply(self, fn):
|
||||
# override nn.Module _apply to prevent need for blur_filter to be registered as a buffer,
|
||||
# this keeps it out of state dict, but allows .cuda(), .type(), etc to work as expected
|
||||
super(BlurPool2d, self)._apply(fn)
|
||||
self.blur_filter = fn(self.blur_filter)
|
||||
|
||||
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # type: ignore
|
||||
C = input_tensor.shape[1]
|
||||
return F.conv2d(
|
||||
self.padding(input_tensor),
|
||||
self.blur_filter.type(input_tensor.dtype).expand(C, -1, -1, -1), stride=self.stride, groups=C)
|
Loading…
Reference in new issue