You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
pytorch-image-models/timm/models/layers/anti_aliasing.py

62 lines
2.3 KiB

5 years ago
import torch
import torch.nn.parallel
import torch.nn as nn
import torch.nn.functional as F
class AntiAliasDownsampleLayer(nn.Module):
def __init__(self, no_jit: bool = False, filt_size: int = 3, stride: int = 2, channels: int = 0):
5 years ago
super(AntiAliasDownsampleLayer, self).__init__()
if no_jit:
5 years ago
self.op = Downsample(filt_size, stride, channels)
else:
self.op = DownsampleJIT(filt_size, stride, channels)
# FIXME I should probably override _apply and clear DownsampleJIT filter cache for .cuda(), .half(), etc calls
5 years ago
def forward(self, x):
return self.op(x)
@torch.jit.script
class DownsampleJIT(object):
def __init__(self, filt_size: int = 3, stride: int = 2, channels: int = 0):
self.stride = stride
self.filt_size = filt_size
self.channels = channels
assert self.filt_size == 3
assert stride == 2
self.filt = {} # lazy init by device for DataParallel compat
5 years ago
def _create_filter(self, like: torch.Tensor):
filt = torch.tensor([1., 2., 1.], dtype=like.dtype, device=like.device)
filt = filt[:, None] * filt[None, :]
5 years ago
filt = filt / torch.sum(filt)
filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
return filt
5 years ago
def __call__(self, input: torch.Tensor):
input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
filt = self.filt.get(str(input.device), self._create_filter(input))
return F.conv2d(input_pad, filt, stride=2, padding=0, groups=input.shape[1])
5 years ago
class Downsample(nn.Module):
def __init__(self, filt_size=3, stride=2, channels=None):
super(Downsample, self).__init__()
self.filt_size = filt_size
self.stride = stride
self.channels = channels
assert self.filt_size == 3
filt = torch.tensor([1., 2., 1.])
filt = filt[:, None] * filt[None, :]
5 years ago
filt = filt / torch.sum(filt)
# self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))
def forward(self, input):
input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
return F.conv2d(input_pad, self.filt, stride=self.stride, padding=0, groups=input.shape[1])