parent
98a7403ed4
commit
011a11d987
@ -0,0 +1,61 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.parallel
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class AntiAliasDownsampleLayer(nn.Module):
|
||||||
|
def __init__(self, remove_aa_jit: bool = False, filt_size: int = 3, stride: int = 2,
|
||||||
|
channels: int = 0):
|
||||||
|
super(AntiAliasDownsampleLayer, self).__init__()
|
||||||
|
if not remove_aa_jit:
|
||||||
|
self.op = DownsampleJIT(filt_size, stride, channels)
|
||||||
|
else:
|
||||||
|
self.op = Downsample(filt_size, stride, channels)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.op(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
class DownsampleJIT(object):
|
||||||
|
def __init__(self, filt_size: int = 3, stride: int = 2, channels: int = 0):
|
||||||
|
self.stride = stride
|
||||||
|
self.filt_size = filt_size
|
||||||
|
self.channels = channels
|
||||||
|
|
||||||
|
assert self.filt_size == 3
|
||||||
|
assert stride == 2
|
||||||
|
a = torch.tensor([1., 2., 1.])
|
||||||
|
|
||||||
|
filt = (a[:, None] * a[None, :]).clone().detach()
|
||||||
|
filt = filt / torch.sum(filt)
|
||||||
|
self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1)).cuda().half()
|
||||||
|
|
||||||
|
def __call__(self, input: torch.Tensor):
|
||||||
|
if input.dtype != self.filt.dtype:
|
||||||
|
self.filt = self.filt.float()
|
||||||
|
input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
|
||||||
|
return F.conv2d(input_pad, self.filt, stride=2, padding=0, groups=input.shape[1])
|
||||||
|
|
||||||
|
|
||||||
|
class Downsample(nn.Module):
|
||||||
|
def __init__(self, filt_size=3, stride=2, channels=None):
|
||||||
|
super(Downsample, self).__init__()
|
||||||
|
self.filt_size = filt_size
|
||||||
|
self.stride = stride
|
||||||
|
self.channels = channels
|
||||||
|
|
||||||
|
|
||||||
|
assert self.filt_size == 3
|
||||||
|
a = torch.tensor([1., 2., 1.])
|
||||||
|
|
||||||
|
filt = (a[:, None] * a[None, :])
|
||||||
|
filt = filt / torch.sum(filt)
|
||||||
|
|
||||||
|
# self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
|
||||||
|
self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
|
||||||
|
return F.conv2d(input_pad, self.filt, stride=self.stride, padding=0, groups=input.shape[1])
|
@ -0,0 +1,53 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class SpaceToDepth(nn.Module):
|
||||||
|
def __init__(self, block_size=4):
|
||||||
|
super().__init__()
|
||||||
|
assert block_size == 4
|
||||||
|
self.bs = block_size
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
N, C, H, W = x.size()
|
||||||
|
x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs)
|
||||||
|
x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs)
|
||||||
|
x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
class SpaceToDepthJit(object):
|
||||||
|
def __call__(self, x: torch.Tensor):
|
||||||
|
# assuming hard-coded that block_size==4 for acceleration
|
||||||
|
N, C, H, W = x.size()
|
||||||
|
x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs)
|
||||||
|
x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs)
|
||||||
|
x = x.view(N, C * 16, H // 4, W // 4) # (N, C*bs^2, H//bs, W//bs)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SpaceToDepthModule(nn.Module):
|
||||||
|
def __init__(self, remove_model_jit=False):
|
||||||
|
super().__init__()
|
||||||
|
if not remove_model_jit:
|
||||||
|
self.op = SpaceToDepthJit()
|
||||||
|
else:
|
||||||
|
self.op = SpaceToDepth()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.op(x)
|
||||||
|
|
||||||
|
|
||||||
|
class DepthToSpace(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, block_size):
|
||||||
|
super().__init__()
|
||||||
|
self.bs = block_size
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
N, C, H, W = x.size()
|
||||||
|
x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W)
|
||||||
|
x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs)
|
||||||
|
x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs)
|
||||||
|
return x
|
@ -0,0 +1,292 @@
|
|||||||
|
"""
|
||||||
|
TResNet: High Performance GPU-Dedicated Architecture
|
||||||
|
https://arxiv.org/pdf/2003.13630.pdf
|
||||||
|
|
||||||
|
Original model: https://github.com/mrT23/TResNet
|
||||||
|
|
||||||
|
"""
|
||||||
|
from functools import partial
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from collections import OrderedDict
|
||||||
|
from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer
|
||||||
|
from .registry import register_model
|
||||||
|
from .helpers import load_pretrained
|
||||||
|
|
||||||
|
try:
|
||||||
|
from inplace_abn import InPlaceABN
|
||||||
|
has_iabn = True
|
||||||
|
except ImportError:
|
||||||
|
has_iabn = False
|
||||||
|
|
||||||
|
__all__ = ['tresnet_m', 'tresnet_l', 'tresnet_xl']
|
||||||
|
|
||||||
|
|
||||||
|
def _cfg(url='', **kwargs):
|
||||||
|
return {
|
||||||
|
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
||||||
|
'crop_pct': 0.875, 'interpolation': 'bilinear',
|
||||||
|
'mean': (0, 0, 0), 'std': (1, 1, 1),
|
||||||
|
'first_conv': 'layer0.conv1', 'classifier': 'head',
|
||||||
|
**kwargs
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
default_cfgs = {
|
||||||
|
'tresnet_m':
|
||||||
|
_cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/tresnet/tresnet_m_80_8.pth'),
|
||||||
|
'tresnet_l':
|
||||||
|
_cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/tresnet/tresnet_l_81_5.pth'),
|
||||||
|
'tresnet_xl':
|
||||||
|
_cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/tresnet/tresnet_xl_82_0.pth')
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class FastGlobalAvgPool2d(nn.Module):
|
||||||
|
def __init__(self, flatten=False):
|
||||||
|
super(FastGlobalAvgPool2d, self).__init__()
|
||||||
|
self.flatten = flatten
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.flatten:
|
||||||
|
in_size = x.size()
|
||||||
|
return x.view((in_size[0], in_size[1], -1)).mean(dim=2)
|
||||||
|
else:
|
||||||
|
return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
|
||||||
|
|
||||||
|
|
||||||
|
class FastSEModule(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, channels, reduction_channels, inplace=True):
|
||||||
|
super(FastSEModule, self).__init__()
|
||||||
|
self.avg_pool = FastGlobalAvgPool2d()
|
||||||
|
self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, padding=0, bias=True)
|
||||||
|
self.relu = nn.ReLU(inplace=inplace)
|
||||||
|
self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, padding=0, bias=True)
|
||||||
|
self.activation = nn.Sigmoid()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x_se = self.avg_pool(x)
|
||||||
|
x_se2 = self.fc1(x_se)
|
||||||
|
x_se2 = self.relu(x_se2)
|
||||||
|
x_se = self.fc2(x_se2)
|
||||||
|
x_se = self.activation(x_se)
|
||||||
|
return x * x_se
|
||||||
|
|
||||||
|
|
||||||
|
def IABN2Float(module: nn.Module) -> nn.Module:
|
||||||
|
"If `module` is IABN don't use half precision."
|
||||||
|
if isinstance(module, InPlaceABN):
|
||||||
|
module.float()
|
||||||
|
for child in module.children(): IABN2Float(child)
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
def conv2d_ABN(ni, nf, stride, activation="leaky_relu", kernel_size=3, activation_param=1e-2, groups=1):
|
||||||
|
return nn.Sequential(
|
||||||
|
nn.Conv2d(ni, nf, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, groups=groups,
|
||||||
|
bias=False),
|
||||||
|
InPlaceABN(num_features=nf, activation=activation, activation_param=activation_param)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BasicBlock(nn.Module):
|
||||||
|
expansion = 1
|
||||||
|
|
||||||
|
def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True, anti_alias_layer=None):
|
||||||
|
super(BasicBlock, self).__init__()
|
||||||
|
if stride == 1:
|
||||||
|
self.conv1 = conv2d_ABN(inplanes, planes, stride=1, activation_param=1e-3)
|
||||||
|
else:
|
||||||
|
if anti_alias_layer is None:
|
||||||
|
self.conv1 = conv2d_ABN(inplanes, planes, stride=2, activation_param=1e-3)
|
||||||
|
else:
|
||||||
|
self.conv1 = nn.Sequential(conv2d_ABN(inplanes, planes, stride=1, activation_param=1e-3),
|
||||||
|
anti_alias_layer(channels=planes, filt_size=3, stride=2))
|
||||||
|
|
||||||
|
self.conv2 = conv2d_ABN(planes, planes, stride=1, activation="identity")
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
self.downsample = downsample
|
||||||
|
self.stride = stride
|
||||||
|
reduce_layer_planes = max(planes * self.expansion // 4, 64)
|
||||||
|
self.se = FastSEModule(planes * self.expansion, reduce_layer_planes) if use_se else None
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.downsample is not None:
|
||||||
|
residual = self.downsample(x)
|
||||||
|
else:
|
||||||
|
residual = x
|
||||||
|
|
||||||
|
out = self.conv1(x)
|
||||||
|
out = self.conv2(out)
|
||||||
|
|
||||||
|
if self.se is not None: out = self.se(out)
|
||||||
|
|
||||||
|
out += residual
|
||||||
|
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Bottleneck(nn.Module):
|
||||||
|
expansion = 4
|
||||||
|
|
||||||
|
def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True, anti_alias_layer=None):
|
||||||
|
super(Bottleneck, self).__init__()
|
||||||
|
self.conv1 = conv2d_ABN(inplanes, planes, kernel_size=1, stride=1, activation="leaky_relu",
|
||||||
|
activation_param=1e-3)
|
||||||
|
if stride == 1:
|
||||||
|
self.conv2 = conv2d_ABN(planes, planes, kernel_size=3, stride=1, activation="leaky_relu",
|
||||||
|
activation_param=1e-3)
|
||||||
|
else:
|
||||||
|
if anti_alias_layer is None:
|
||||||
|
self.conv2 = conv2d_ABN(planes, planes, kernel_size=3, stride=2, activation="leaky_relu",
|
||||||
|
activation_param=1e-3)
|
||||||
|
else:
|
||||||
|
self.conv2 = nn.Sequential(conv2d_ABN(planes, planes, kernel_size=3, stride=1,
|
||||||
|
activation="leaky_relu", activation_param=1e-3),
|
||||||
|
anti_alias_layer(channels=planes, filt_size=3, stride=2))
|
||||||
|
|
||||||
|
self.conv3 = conv2d_ABN(planes, planes * self.expansion, kernel_size=1, stride=1,
|
||||||
|
activation="identity")
|
||||||
|
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
self.downsample = downsample
|
||||||
|
self.stride = stride
|
||||||
|
|
||||||
|
reduce_layer_planes = max(planes * self.expansion // 8, 64)
|
||||||
|
self.se = FastSEModule(planes, reduce_layer_planes) if use_se else None
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.downsample is not None:
|
||||||
|
residual = self.downsample(x)
|
||||||
|
else:
|
||||||
|
residual = x
|
||||||
|
|
||||||
|
out = self.conv1(x)
|
||||||
|
out = self.conv2(out)
|
||||||
|
if self.se is not None: out = self.se(out)
|
||||||
|
|
||||||
|
out = self.conv3(out)
|
||||||
|
out = out + residual # no inplace
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class TResNet(nn.Module):
|
||||||
|
def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, remove_aa_jit=False):
|
||||||
|
if not has_iabn:
|
||||||
|
raise " For TResNet models, please install InplaceABN: 'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11' "
|
||||||
|
|
||||||
|
super(TResNet, self).__init__()
|
||||||
|
|
||||||
|
# JIT layers
|
||||||
|
space_to_depth = SpaceToDepthModule()
|
||||||
|
anti_alias_layer = partial(AntiAliasDownsampleLayer, remove_aa_jit=remove_aa_jit)
|
||||||
|
global_pool_layer = FastGlobalAvgPool2d(flatten=True)
|
||||||
|
|
||||||
|
# TResnet stages
|
||||||
|
self.inplanes = int(64 * width_factor)
|
||||||
|
self.planes = int(64 * width_factor)
|
||||||
|
conv1 = conv2d_ABN(in_chans * 16, self.planes, stride=1, kernel_size=3)
|
||||||
|
layer1 = self._make_layer(BasicBlock, self.planes, layers[0], stride=1, use_se=True,
|
||||||
|
anti_alias_layer=anti_alias_layer) # 56x56
|
||||||
|
layer2 = self._make_layer(BasicBlock, self.planes * 2, layers[1], stride=2, use_se=True,
|
||||||
|
anti_alias_layer=anti_alias_layer) # 28x28
|
||||||
|
layer3 = self._make_layer(Bottleneck, self.planes * 4, layers[2], stride=2, use_se=True,
|
||||||
|
anti_alias_layer=anti_alias_layer) # 14x14
|
||||||
|
layer4 = self._make_layer(Bottleneck, self.planes * 8, layers[3], stride=2, use_se=False,
|
||||||
|
anti_alias_layer=anti_alias_layer) # 7x7
|
||||||
|
|
||||||
|
# body
|
||||||
|
self.body = nn.Sequential(OrderedDict([
|
||||||
|
('SpaceToDepth', space_to_depth),
|
||||||
|
('conv1', conv1),
|
||||||
|
('layer1', layer1),
|
||||||
|
('layer2', layer2),
|
||||||
|
('layer3', layer3),
|
||||||
|
('layer4', layer4)]))
|
||||||
|
|
||||||
|
# head
|
||||||
|
self.embeddings = []
|
||||||
|
self.global_pool = nn.Sequential(OrderedDict([('global_pool_layer', global_pool_layer)]))
|
||||||
|
self.num_features = (self.planes * 8) * Bottleneck.expansion
|
||||||
|
fc = nn.Linear(self.num_features, num_classes)
|
||||||
|
self.head = nn.Sequential(OrderedDict([('fc', fc)]))
|
||||||
|
|
||||||
|
# model initilization
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
|
||||||
|
elif isinstance(m, nn.BatchNorm2d) or isinstance(m, InPlaceABN):
|
||||||
|
nn.init.constant_(m.weight, 1)
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
# residual connections special initialization
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, BasicBlock):
|
||||||
|
m.conv2[1].weight = nn.Parameter(torch.zeros_like(m.conv2[1].weight)) # BN to zero
|
||||||
|
if isinstance(m, Bottleneck):
|
||||||
|
m.conv3[1].weight = nn.Parameter(torch.zeros_like(m.conv3[1].weight)) # BN to zero
|
||||||
|
if isinstance(m, nn.Linear): m.weight.data.normal_(0, 0.01)
|
||||||
|
|
||||||
|
def _make_layer(self, block, planes, blocks, stride=1, use_se=True, anti_alias_layer=None):
|
||||||
|
downsample = None
|
||||||
|
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||||
|
layers = []
|
||||||
|
if stride == 2:
|
||||||
|
# avg pooling before 1x1 conv
|
||||||
|
layers.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True, count_include_pad=False))
|
||||||
|
layers += [conv2d_ABN(self.inplanes, planes * block.expansion, kernel_size=1, stride=1,
|
||||||
|
activation="identity")]
|
||||||
|
downsample = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
layers = []
|
||||||
|
layers.append(block(self.inplanes, planes, stride, downsample, use_se=use_se,
|
||||||
|
anti_alias_layer=anti_alias_layer))
|
||||||
|
self.inplanes = planes * block.expansion
|
||||||
|
for i in range(1, blocks): layers.append(
|
||||||
|
block(self.inplanes, planes, use_se=use_se, anti_alias_layer=anti_alias_layer))
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.body(x)
|
||||||
|
self.embeddings = self.global_pool(x)
|
||||||
|
logits = self.head(self.embeddings)
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
def filter_fn(input):
|
||||||
|
return input['model']
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def tresnet_m(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||||
|
default_cfg = default_cfgs['tresnet_m']
|
||||||
|
model = TResNet(layers=[3, 4, 11, 3], num_classes=num_classes, in_chans=in_chans)
|
||||||
|
model.default_cfg = default_cfg
|
||||||
|
if pretrained:
|
||||||
|
load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=filter_fn)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def tresnet_l(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||||
|
default_cfg = default_cfgs['tresnet_l']
|
||||||
|
model = TResNet(layers=[4, 5, 18, 3], num_classes=num_classes, in_chans=in_chans, width_factor=1.2)
|
||||||
|
model.default_cfg = default_cfg
|
||||||
|
if pretrained:
|
||||||
|
load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=filter_fn)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def tresnet_xl(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||||
|
default_cfg = default_cfgs['tresnet_xl']
|
||||||
|
model = TResNet(layers=[4, 5, 24, 3], num_classes=num_classes, in_chans=in_chans, width_factor=1.3)
|
||||||
|
model.default_cfg = default_cfg
|
||||||
|
if pretrained:
|
||||||
|
load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=filter_fn)
|
||||||
|
return model
|
Loading…
Reference in new issue