|
|
@ -14,7 +14,7 @@ import torch.nn as nn
|
|
|
|
import torch.nn.functional as F
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
|
|
|
from .helpers import build_model_with_cfg
|
|
|
|
from .helpers import build_model_with_cfg
|
|
|
|
from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, InplaceAbn, ClassifierHead
|
|
|
|
from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, InplaceAbn, ClassifierHead, SEModule
|
|
|
|
from .registry import register_model
|
|
|
|
from .registry import register_model
|
|
|
|
|
|
|
|
|
|
|
|
__all__ = ['tresnet_m', 'tresnet_l', 'tresnet_xl']
|
|
|
|
__all__ = ['tresnet_m', 'tresnet_l', 'tresnet_xl']
|
|
|
@ -49,40 +49,6 @@ default_cfgs = {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FastGlobalAvgPool2d(nn.Module):
|
|
|
|
|
|
|
|
def __init__(self, flatten=False):
|
|
|
|
|
|
|
|
super(FastGlobalAvgPool2d, self).__init__()
|
|
|
|
|
|
|
|
self.flatten = flatten
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
|
|
|
if self.flatten:
|
|
|
|
|
|
|
|
return x.mean((2, 3))
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
return x.mean((2, 3), keepdim=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def feat_mult(self):
|
|
|
|
|
|
|
|
return 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FastSEModule(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, channels, reduction_channels, inplace=True):
|
|
|
|
|
|
|
|
super(FastSEModule, self).__init__()
|
|
|
|
|
|
|
|
self.avg_pool = FastGlobalAvgPool2d()
|
|
|
|
|
|
|
|
self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, padding=0, bias=True)
|
|
|
|
|
|
|
|
self.relu = nn.ReLU(inplace=inplace)
|
|
|
|
|
|
|
|
self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, padding=0, bias=True)
|
|
|
|
|
|
|
|
self.activation = nn.Sigmoid()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
|
|
|
x_se = self.avg_pool(x)
|
|
|
|
|
|
|
|
x_se2 = self.fc1(x_se)
|
|
|
|
|
|
|
|
x_se2 = self.relu(x_se2)
|
|
|
|
|
|
|
|
x_se = self.fc2(x_se2)
|
|
|
|
|
|
|
|
x_se = self.activation(x_se)
|
|
|
|
|
|
|
|
return x * x_se
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def IABN2Float(module: nn.Module) -> nn.Module:
|
|
|
|
def IABN2Float(module: nn.Module) -> nn.Module:
|
|
|
|
"""If `module` is IABN don't use half precision."""
|
|
|
|
"""If `module` is IABN don't use half precision."""
|
|
|
|
if isinstance(module, InplaceAbn):
|
|
|
|
if isinstance(module, InplaceAbn):
|
|
|
@ -119,8 +85,8 @@ class BasicBlock(nn.Module):
|
|
|
|
self.relu = nn.ReLU(inplace=True)
|
|
|
|
self.relu = nn.ReLU(inplace=True)
|
|
|
|
self.downsample = downsample
|
|
|
|
self.downsample = downsample
|
|
|
|
self.stride = stride
|
|
|
|
self.stride = stride
|
|
|
|
reduce_layer_planes = max(planes * self.expansion // 4, 64)
|
|
|
|
reduction_chs = max(planes * self.expansion // 4, 64)
|
|
|
|
self.se = FastSEModule(planes * self.expansion, reduce_layer_planes) if use_se else None
|
|
|
|
self.se = SEModule(planes * self.expansion, reduction_channels=reduction_chs) if use_se else None
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
def forward(self, x):
|
|
|
|
if self.downsample is not None:
|
|
|
|
if self.downsample is not None:
|
|
|
@ -159,8 +125,8 @@ class Bottleneck(nn.Module):
|
|
|
|
conv2d_iabn(planes, planes, kernel_size=3, stride=1, act_layer=act_layer, act_param=1e-3),
|
|
|
|
conv2d_iabn(planes, planes, kernel_size=3, stride=1, act_layer=act_layer, act_param=1e-3),
|
|
|
|
aa_layer(channels=planes, filt_size=3, stride=2))
|
|
|
|
aa_layer(channels=planes, filt_size=3, stride=2))
|
|
|
|
|
|
|
|
|
|
|
|
reduce_layer_planes = max(planes * self.expansion // 8, 64)
|
|
|
|
reduction_chs = max(planes * self.expansion // 8, 64)
|
|
|
|
self.se = FastSEModule(planes, reduce_layer_planes) if use_se else None
|
|
|
|
self.se = SEModule(planes, reduction_channels=reduction_chs) if use_se else None
|
|
|
|
|
|
|
|
|
|
|
|
self.conv3 = conv2d_iabn(
|
|
|
|
self.conv3 = conv2d_iabn(
|
|
|
|
planes, planes * self.expansion, kernel_size=1, stride=1, act_layer="identity")
|
|
|
|
planes, planes * self.expansion, kernel_size=1, stride=1, act_layer="identity")
|
|
|
@ -189,7 +155,7 @@ class Bottleneck(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
class TResNet(nn.Module):
|
|
|
|
class TResNet(nn.Module):
|
|
|
|
def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, no_aa_jit=False,
|
|
|
|
def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, no_aa_jit=False,
|
|
|
|
global_pool='avg', drop_rate=0.):
|
|
|
|
global_pool='fast', drop_rate=0.):
|
|
|
|
self.num_classes = num_classes
|
|
|
|
self.num_classes = num_classes
|
|
|
|
self.drop_rate = drop_rate
|
|
|
|
self.drop_rate = drop_rate
|
|
|
|
super(TResNet, self).__init__()
|
|
|
|
super(TResNet, self).__init__()
|
|
|
@ -272,7 +238,7 @@ class TResNet(nn.Module):
|
|
|
|
def get_classifier(self):
|
|
|
|
def get_classifier(self):
|
|
|
|
return self.head.fc
|
|
|
|
return self.head.fc
|
|
|
|
|
|
|
|
|
|
|
|
def reset_classifier(self, num_classes, global_pool='avg'):
|
|
|
|
def reset_classifier(self, num_classes, global_pool='fast'):
|
|
|
|
self.head = ClassifierHead(
|
|
|
|
self.head = ClassifierHead(
|
|
|
|
self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
|
|
|
|
self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
|
|
|
|
|
|
|
|
|
|
|
|