diff --git a/timm/models/layers/anti_aliasing.py b/timm/models/layers/anti_aliasing.py index fd6457bf..38f96ee3 100644 --- a/timm/models/layers/anti_aliasing.py +++ b/timm/models/layers/anti_aliasing.py @@ -1,61 +1,61 @@ -import torch -import torch.nn.parallel -import torch.nn as nn -import torch.nn.functional as F - - -class AntiAliasDownsampleLayer(nn.Module): - def __init__(self, no_jit: bool = False, filt_size: int = 3, stride: int = 2, channels: int = 0): - super(AntiAliasDownsampleLayer, self).__init__() - if no_jit: - self.op = Downsample(filt_size, stride, channels) - else: - self.op = DownsampleJIT(filt_size, stride, channels) - - # FIXME I should probably override _apply and clear DownsampleJIT filter cache for .cuda(), .half(), etc calls - - def forward(self, x): - return self.op(x) - - -@torch.jit.script -class DownsampleJIT(object): - def __init__(self, filt_size: int = 3, stride: int = 2, channels: int = 0): - self.stride = stride - self.filt_size = filt_size - self.channels = channels - assert self.filt_size == 3 - assert stride == 2 - self.filt = {} # lazy init by device for DataParallel compat - - def _create_filter(self, like: torch.Tensor): - filt = torch.tensor([1., 2., 1.], dtype=like.dtype, device=like.device) - filt = filt[:, None] * filt[None, :] - filt = filt / torch.sum(filt) - filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1)) - return filt - - def __call__(self, input: torch.Tensor): - input_pad = F.pad(input, (1, 1, 1, 1), 'reflect') - filt = self.filt.get(str(input.device), self._create_filter(input)) - return F.conv2d(input_pad, filt, stride=2, padding=0, groups=input.shape[1]) - - -class Downsample(nn.Module): - def __init__(self, filt_size=3, stride=2, channels=None): - super(Downsample, self).__init__() - self.filt_size = filt_size - self.stride = stride - self.channels = channels - - assert self.filt_size == 3 - filt = torch.tensor([1., 2., 1.]) - filt = filt[:, None] * filt[None, :] - filt = filt / torch.sum(filt) - - # self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1)) - self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1))) - - def forward(self, input): - input_pad = F.pad(input, (1, 1, 1, 1), 'reflect') - return F.conv2d(input_pad, self.filt, stride=self.stride, padding=0, groups=input.shape[1]) +import torch +import torch.nn.parallel +import torch.nn as nn +import torch.nn.functional as F + + +class AntiAliasDownsampleLayer(nn.Module): + def __init__(self, no_jit: bool = False, filt_size: int = 3, stride: int = 2, channels: int = 0): + super(AntiAliasDownsampleLayer, self).__init__() + if no_jit: + self.op = Downsample(filt_size, stride, channels) + else: + self.op = DownsampleJIT(filt_size, stride, channels) + + # FIXME I should probably override _apply and clear DownsampleJIT filter cache for .cuda(), .half(), etc calls + + def forward(self, x): + return self.op(x) + + +@torch.jit.script +class DownsampleJIT(object): + def __init__(self, filt_size: int = 3, stride: int = 2, channels: int = 0): + self.stride = stride + self.filt_size = filt_size + self.channels = channels + assert self.filt_size == 3 + assert stride == 2 + self.filt = {} # lazy init by device for DataParallel compat + + def _create_filter(self, like: torch.Tensor): + filt = torch.tensor([1., 2., 1.], dtype=like.dtype, device=like.device) + filt = filt[:, None] * filt[None, :] + filt = filt / torch.sum(filt) + filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1)) + return filt + + def __call__(self, input: torch.Tensor): + input_pad = F.pad(input, (1, 1, 1, 1), 'reflect') + filt = self.filt.get(str(input.device), self._create_filter(input)) + return F.conv2d(input_pad, filt, stride=2, padding=0, groups=input.shape[1]) + + +class Downsample(nn.Module): + def __init__(self, filt_size=3, stride=2, channels=None): + super(Downsample, self).__init__() + self.filt_size = filt_size + self.stride = stride + self.channels = channels + + assert self.filt_size == 3 + filt = torch.tensor([1., 2., 1.]) + filt = filt[:, None] * filt[None, :] + filt = filt / torch.sum(filt) + + # self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1)) + self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1))) + + def forward(self, input): + input_pad = F.pad(input, (1, 1, 1, 1), 'reflect') + return F.conv2d(input_pad, self.filt, stride=self.stride, padding=0, groups=input.shape[1]) diff --git a/timm/models/layers/space_to_depth.py b/timm/models/layers/space_to_depth.py index 2c378fe1..a7e8e0b2 100644 --- a/timm/models/layers/space_to_depth.py +++ b/timm/models/layers/space_to_depth.py @@ -1,53 +1,53 @@ -import torch -import torch.nn as nn - - -class SpaceToDepth(nn.Module): - def __init__(self, block_size=4): - super().__init__() - assert block_size == 4 - self.bs = block_size - - def forward(self, x): - N, C, H, W = x.size() - x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs) - x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) - x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs) - return x - - -@torch.jit.script -class SpaceToDepthJit(object): - def __call__(self, x: torch.Tensor): - # assuming hard-coded that block_size==4 for acceleration - N, C, H, W = x.size() - x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs) - x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) - x = x.view(N, C * 16, H // 4, W // 4) # (N, C*bs^2, H//bs, W//bs) - return x - - -class SpaceToDepthModule(nn.Module): - def __init__(self, no_jit=False): - super().__init__() - if not no_jit: - self.op = SpaceToDepthJit() - else: - self.op = SpaceToDepth() - - def forward(self, x): - return self.op(x) - - -class DepthToSpace(nn.Module): - - def __init__(self, block_size): - super().__init__() - self.bs = block_size - - def forward(self, x): - N, C, H, W = x.size() - x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W) - x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs) - x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs) - return x \ No newline at end of file +import torch +import torch.nn as nn + + +class SpaceToDepth(nn.Module): + def __init__(self, block_size=4): + super().__init__() + assert block_size == 4 + self.bs = block_size + + def forward(self, x): + N, C, H, W = x.size() + x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs) + x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) + x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs) + return x + + +@torch.jit.script +class SpaceToDepthJit(object): + def __call__(self, x: torch.Tensor): + # assuming hard-coded that block_size==4 for acceleration + N, C, H, W = x.size() + x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs) + x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) + x = x.view(N, C * 16, H // 4, W // 4) # (N, C*bs^2, H//bs, W//bs) + return x + + +class SpaceToDepthModule(nn.Module): + def __init__(self, no_jit=False): + super().__init__() + if not no_jit: + self.op = SpaceToDepthJit() + else: + self.op = SpaceToDepth() + + def forward(self, x): + return self.op(x) + + +class DepthToSpace(nn.Module): + + def __init__(self, block_size): + super().__init__() + self.bs = block_size + + def forward(self, x): + N, C, H, W = x.size() + x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W) + x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs) + x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs) + return x