@ -8,8 +8,9 @@ Original model: https://github.com/mrT23/TResNet
from functools import partial
import torch
import torch . nn as nn
import torch . nn . functional as F
from collections import OrderedDict
from . layers import SpaceToDepthModule , AntiAliasDownsampleLayer
from . layers import SpaceToDepthModule , AntiAliasDownsampleLayer , SelectAdaptivePool2d
from . registry import register_model
from . helpers import load_pretrained
@ -27,18 +28,27 @@ def _cfg(url='', **kwargs):
' url ' : url , ' num_classes ' : 1000 , ' input_size ' : ( 3 , 224 , 224 ) , ' pool_size ' : ( 7 , 7 ) ,
' crop_pct ' : 0.875 , ' interpolation ' : ' bilinear ' ,
' mean ' : ( 0 , 0 , 0 ) , ' std ' : ( 1 , 1 , 1 ) ,
' first_conv ' : ' layer0.conv1 ' , ' classifier ' : ' head ' ,
' first_conv ' : ' layer0.conv1 ' , ' classifier ' : ' head .fc ' ,
* * kwargs
}
# Per-variant pretrained configuration table, keyed by model name. Every value is
# produced by `_cfg(...)` from a checkpoint URL plus optional keyword overrides.
# NOTE(review): the three base variants are listed twice with different URLs
# (aliyuncs.com entries first, github.com release entries second), and the
# aliyuncs `tresnet_xl` entry is not followed by a comma. This looks like the
# before/after text of one edit shown together; confirm which entries are live.
default_cfgs = {
' tresnet_m ' :
_cfg ( url = ' https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/tresnet/tresnet_m_80_8.pth ' ) ,
' tresnet_l ' :
_cfg ( url = ' https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/tresnet/tresnet_l_81_5.pth ' ) ,
' tresnet_xl ' :
_cfg ( url = ' https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/tresnet/tresnet_xl_82_0.pth ' )
' tresnet_m ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_m_80_8-dbc13962.pth ' ) ,
' tresnet_l ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_81_5-235b486c.pth ' ) ,
' tresnet_xl ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_xl_82_0-a2d51b00.pth ' ) ,
# The `_448` variants additionally override `input_size` with (3, 448, 448).
' tresnet_m_448 ' : _cfg (
input_size = ( 3 , 448 , 448 ) ,
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_m_448-bc359d10.pth ' ) ,
' tresnet_l_448 ' : _cfg (
input_size = ( 3 , 448 , 448 ) ,
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_448-940d0cd1.pth ' ) ,
' tresnet_xl_448 ' : _cfg (
input_size = ( 3 , 448 , 448 ) ,
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_xl_448-8c1815de.pth ' )
}
@ -54,6 +64,9 @@ class FastGlobalAvgPool2d(nn.Module):
else :
return x . view ( x . size ( 0 ) , x . size ( 1 ) , - 1 ) . mean ( - 1 ) . view ( x . size ( 0 ) , x . size ( 1 ) , 1 , 1 )
def feat_mult(self):
    """Return the feature-count multiplier reported by this pooling layer.

    The classifier's input width is computed as ``num_features * feat_mult()``;
    this layer always reports a factor of 1.
    """
    multiplier = 1
    return multiplier
class FastSEModule ( nn . Module ) :
@ -78,14 +91,15 @@ def IABN2Float(module: nn.Module) -> nn.Module:
" If `module` is IABN don ' t use half precision. "
if isinstance ( module , InPlaceABN ) :
module . float ( )
for child in module . children ( ) : IABN2Float ( child )
for child in module . children ( ) :
IABN2Float ( child )
return module
# Build an `nn.Sequential` holding a bias-free `nn.Conv2d` followed by an `InPlaceABN` layer.
#   ni, nf: input / output channel counts given to the convolution.
#   stride, kernel_size, groups: forwarded to `nn.Conv2d`; padding is derived from `kernel_size`.
#   activation, activation_param: forwarded to `InPlaceABN` together with `num_features=nf`.
# NOTE(review): the `nn.Conv2d(...)` call appears twice below with the same arguments
# (only the line wrapping differs), presumably the old and new formatting of a single
# statement. Confirm that only one convolution is intended.
def conv2d_ABN ( ni , nf , stride , activation = " leaky_relu " , kernel_size = 3 , activation_param = 1e-2 , groups = 1 ) :
return nn . Sequential (
nn . Conv2d ( ni , nf , kernel_size = kernel_size , stride = stride , padding = kernel_size / / 2 , groups = groups ,
bias = False ) ,
nn . Conv2d (
ni , nf , kernel_size = kernel_size , stride = stride , padding = kernel_size / / 2 , groups = groups , bias = False ) ,
InPlaceABN ( num_features = nf , activation = activation , activation_param = activation_param )
)
@ -101,8 +115,9 @@ class BasicBlock(nn.Module):
if anti_alias_layer is None :
self . conv1 = conv2d_ABN ( inplanes , planes , stride = 2 , activation_param = 1e-3 )
else :
self . conv1 = nn . Sequential ( conv2d_ABN ( inplanes , planes , stride = 1 , activation_param = 1e-3 ) ,
anti_alias_layer ( channels = planes , filt_size = 3 , stride = 2 ) )
self . conv1 = nn . Sequential (
conv2d_ABN ( inplanes , planes , stride = 1 , activation_param = 1e-3 ) ,
anti_alias_layer ( channels = planes , filt_size = 3 , stride = 2 ) )
self . conv2 = conv2d_ABN ( planes , planes , stride = 1 , activation = " identity " )
self . relu = nn . ReLU ( inplace = True )
@ -120,12 +135,11 @@ class BasicBlock(nn.Module):
out = self . conv1 ( x )
out = self . conv2 ( out )
if self . se is not None : out = self . se ( out )
if self . se is not None :
out = self . se ( out )
out + = residual
out = self . relu ( out )
return out
@ -134,22 +148,22 @@ class Bottleneck(nn.Module):
def __init__ ( self , inplanes , planes , stride = 1 , downsample = None , use_se = True , anti_alias_layer = None ) :
super ( Bottleneck , self ) . __init__ ( )
self . conv1 = conv2d_ABN ( inplanes , planes , kernel_size = 1 , stride = 1 , activation = " leaky_relu " ,
activation_param = 1e-3 )
self . conv1 = conv2d_ABN (
inplanes , planes , kernel_size = 1 , stride = 1 , activation = " leaky_relu " , activation_param = 1e-3 )
if stride == 1 :
self . conv2 = conv2d_ABN ( planes , planes , kernel_size = 3 , stride = 1 , activation = " leaky_relu " ,
activation_param = 1e-3 )
self . conv2 = conv2d_ABN (
planes , planes , kernel_size = 3 , stride = 1 , activation = " leaky_relu " , activation_param = 1e-3 )
else :
if anti_alias_layer is None :
self . conv2 = conv2d_ABN ( planes , planes , kernel_size = 3 , stride = 2 , activation = " leaky_relu " ,
activation_param = 1e-3 )
self . conv2 = conv2d_ABN (
planes , planes , kernel_size = 3 , stride = 2 , activation = " leaky_relu " , activation_param = 1e-3 )
else :
self . conv2 = nn . Sequential ( conv2d_ABN ( planes , planes , kernel_size = 3 , stride = 1 ,
activation = " leaky_relu " , activation_param = 1e-3 ) ,
anti_alias_layer ( channels = planes , filt_size = 3 , stride = 2 ) )
self . conv2 = nn . Sequential (
conv2d_ABN ( planes , planes , kernel_size = 3 , stride = 1 , activation = " leaky_relu " , activation_param = 1e-3 ) ,
anti_alias_layer ( channels = planes , filt_size = 3 , stride = 2 ) )
self . conv3 = conv2d_ABN ( planes , planes * self . expansion , kernel_size = 1 , stride = 1 ,
activation = " identity " )
self . conv3 = conv2d_ABN (
planes , planes * self . expansion , kernel_size = 1 , stride = 1 , activation = " identity " )
self . relu = nn . ReLU ( inplace = True )
self . downsample = downsample
@ -166,7 +180,8 @@ class Bottleneck(nn.Module):
out = self . conv1 ( x )
out = self . conv2 ( out )
if self . se is not None : out = self . se ( out )
if self . se is not None :
out = self . se ( out )
out = self . conv3 ( out )
out = out + residual # no inplace
@ -176,29 +191,32 @@ class Bottleneck(nn.Module):
class TResNet ( nn . Module ) :
def __init__ ( self , layers , in_chans = 3 , num_classes = 1000 , width_factor = 1.0 , remove_aa_jit = False ) :
def __init__ ( self , layers , in_chans = 3 , num_classes = 1000 , width_factor = 1.0 , no_aa_jit = False ,
global_pool = ' avg ' , drop_rate = 0. ) :
if not has_iabn :
raise " For TResNet models, please install InplaceABN: ' pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11 ' "
raise ImportError (
" For TResNet models, please install InplaceABN: "
" ' pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11 ' " )
self . num_classes = num_classes
self . drop_rate = drop_rate
super ( TResNet , self ) . __init__ ( )
# JIT layers
space_to_depth = SpaceToDepthModule ( )
anti_alias_layer = partial ( AntiAliasDownsampleLayer , remove_aa_jit = remove_aa_jit )
global_pool_layer = FastGlobalAvgPool2d ( flatten = True )
anti_alias_layer = partial ( AntiAliasDownsampleLayer , no_jit = no_aa_jit )
# TResnet stages
self . inplanes = int ( 64 * width_factor )
self . planes = int ( 64 * width_factor )
conv1 = conv2d_ABN ( in_chans * 16 , self . planes , stride = 1 , kernel_size = 3 )
layer1 = self . _make_layer ( BasicBlock , self . planes , layers [ 0 ] , stride = 1 , use_se = True ,
anti_alias_layer = anti_alias_layer ) # 56x56
layer2 = self . _make_layer ( BasicBlock , self . planes * 2 , layers [ 1 ] , stride = 2 , use_se = True ,
anti_alias_layer = anti_alias_layer ) # 28x28
layer3 = self . _make_layer ( Bottleneck , self . planes * 4 , layers [ 2 ] , stride = 2 , use_se = True ,
anti_alias_layer = anti_alias_layer ) # 14x14
layer4 = self . _make_layer ( Bottleneck , self . planes * 8 , layers [ 3 ] , stride = 2 , use_se = False ,
anti_alias_layer = anti_alias_layer ) # 7x7
layer1 = self . _make_layer (
BasicBlock , self . planes , layers [ 0 ] , stride = 1 , use_se = True , anti_alias_layer = anti_alias_layer ) # 56x56
layer2 = self . _make_layer (
BasicBlock , self . planes * 2 , layers [ 1 ] , stride = 2 , use_se = True , anti_alias_layer = anti_alias_layer ) # 28x28
layer3 = self . _make_layer (
Bottleneck , self . planes * 4 , layers [ 2 ] , stride = 2 , use_se = True , anti_alias_layer = anti_alias_layer ) # 14x14
layer4 = self . _make_layer (
Bottleneck , self . planes * 8 , layers [ 3 ] , stride = 2 , use_se = False , anti_alias_layer = anti_alias_layer ) # 7x7
# body
self . body = nn . Sequential ( OrderedDict ( [
@ -210,11 +228,10 @@ class TResNet(nn.Module):
( ' layer4 ' , layer4 ) ] ) )
# head
self . embeddings = [ ]
self . global_pool = nn . Sequential ( OrderedDict ( [ ( ' global_pool_layer ' , global_pool_layer ) ] ) )
self . num_features = ( self . planes * 8 ) * Bottleneck . expansion
fc = nn . Linear ( self . num_features , num_classes )
self . head = nn . Sequential ( OrderedDict ( [ ( ' fc ' , fc ) ] ) )
self . global_pool = None
self . head = None
self . reset_classifier ( num_classes , global_pool )
# model initialization
for m in self . modules ( ) :
@ -239,54 +256,104 @@ class TResNet(nn.Module):
if stride == 2 :
# avg pooling before 1x1 conv
layers . append ( nn . AvgPool2d ( kernel_size = 2 , stride = 2 , ceil_mode = True , count_include_pad = False ) )
layers + = [ conv2d_ABN ( self . inplanes , planes * block . expansion , kernel_size = 1 , stride = 1 ,
activation = " identity " ) ]
layers + = [ conv2d_ABN (
self . inplanes , planes * block . expansion , kernel_size = 1 , stride = 1 , activation = " identity " ) ]
downsample = nn . Sequential ( * layers )
layers = [ ]
layers . append ( block ( self . inplanes , planes , stride , downsample , use_se = use_se ,
anti_alias_layer = anti_alias_layer ) )
layers . append ( block (
self . inplanes , planes , stride , downsample , use_se = use_se , anti_alias_layer = anti_alias_layer ) )
self . inplanes = planes * block . expansion
for i in range ( 1 , blocks ) : layers . append (
block ( self . inplanes , planes , use_se = use_se , anti_alias_layer = anti_alias_layer ) )
for i in range ( 1 , blocks ) :
layers . append (
block ( self . inplanes , planes , use_se = use_se , anti_alias_layer = anti_alias_layer ) )
return nn . Sequential ( * layers )
def forward(self, x):
    """Run the body, pool its output, and classify the pooled result.

    The pooled output is also stored on ``self.embeddings`` so it can be
    read after the call.

    Args:
        x: Input passed to ``self.body``.

    Returns:
        The output of ``self.head`` applied to the pooled embeddings.
    """
    body_out = self.body(x)
    pooled = self.global_pool(body_out)
    self.embeddings = pooled
    return self.head(pooled)
def get_classifier(self):
    """Return the final classification layer, or ``None`` if the model has no head.

    ``reset_classifier`` leaves ``self.head`` as ``None`` when it is called with a
    falsy ``num_classes``. Dereferencing ``.fc`` in that state would raise
    ``AttributeError``, so that case is reported as ``None`` instead.

    Returns:
        ``self.head.fc`` when a head exists, otherwise ``None``.
    """
    head = self.head
    return None if head is None else head.fc
def reset_classifier(self, num_classes, global_pool=' avg '):
    """Replace the pooling layer and the classifier head.

    Args:
        num_classes: Output size of the new linear classifier. A falsy value
            leaves ``self.head`` set to ``None``.
        global_pool: Pooling selector. The default value selects
            ``FastGlobalAvgPool2d``; any other value is handed to
            ``SelectAdaptivePool2d`` as ``pool_type``.
    """
    self.num_classes = num_classes

    use_fast_avg = global_pool == ' avg '
    self.global_pool = (
        FastGlobalAvgPool2d(flatten=True)
        if use_fast_avg
        else SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
    )

    # Start without a head; one is only attached when classes were requested.
    self.head = None
    if not num_classes:
        return

    # The pooling layer reports how much it widens the feature vector.
    in_features = self.num_features * self.global_pool.feat_mult()
    classifier = nn.Linear(in_features, num_classes)
    self.head = nn.Sequential(OrderedDict([(' fc ', classifier)]))
def filter_fn(input):
    """Return the entry stored under the model key of ``input``.

    Args:
        input: Subscriptable object holding that key; presumably a loaded
            checkpoint mapping, since this function is handed to
            ``load_pretrained`` as ``filter_fn`` (confirm against that helper).

    Returns:
        The value found under the key.
    """
    model_entry = input[' model ']
    return model_entry
def forward_features(self, x):
    """Return the output of ``self.body`` for ``x``, without pooling or classification."""
    body = self.body
    return body(x)
def forward(self, x):
    """Compute the model output for ``x``.

    The input goes through ``forward_features`` and ``global_pool``. Dropout is
    applied when ``self.drop_rate`` is non-zero, and the classifier head is
    applied when one exists.

    Args:
        x: Input passed to ``self.forward_features``.

    Returns:
        The head's output, or the pooled (and possibly dropped-out) features
        when ``self.head`` is ``None``.
    """
    x = self.forward_features(x)
    x = self.global_pool(x)
    if self.drop_rate:
        x = F.dropout(x, p=float(self.drop_rate), training=self.training)
    # reset_classifier() leaves `head` as None for a falsy num_classes; calling it
    # unconditionally would raise TypeError, so return the pooled features instead.
    if self.head is not None:
        x = self.head(x)
    return x
# Registered factory for the `tresnet_m` variant: builds a TResNet with layers
# [3, 4, 11, 3], attaches the matching `default_cfgs` entry as `model.default_cfg`,
# and calls `load_pretrained` when `pretrained` is true. Returns the model.
# NOTE(review): the `TResNet(...)` construction appears twice (without and with
# `**kwargs`) and so does the `load_pretrained(...)` call (with and without
# `filter_fn`). These are presumably the before/after forms of single statements;
# confirm which lines are live.
@register_model
def tresnet_m ( pretrained = False , num_classes = 1000 , in_chans = 3 , * * kwargs ) :
default_cfg = default_cfgs [ ' tresnet_m ' ]
model = TResNet ( layers = [ 3 , 4 , 11 , 3 ] , num_classes = num_classes , in_chans = in_chans )
model = TResNet ( layers = [ 3 , 4 , 11 , 3 ] , num_classes = num_classes , in_chans = in_chans , * * kwargs )
model . default_cfg = default_cfg
if pretrained :
load_pretrained ( model , default_cfg , num_classes , in_chans , filter_fn = filter_fn )
load_pretrained ( model , default_cfg , num_classes , in_chans )
return model
# Registered factory for the `tresnet_l` variant: builds a TResNet with layers
# [4, 5, 18, 3] and width_factor 1.2, attaches the matching `default_cfgs` entry as
# `model.default_cfg`, and calls `load_pretrained` when `pretrained` is true.
# Returns the model.
# NOTE(review): the `TResNet(...)` construction appears twice (without and with
# `**kwargs`) and so does the `load_pretrained(...)` call (with and without
# `filter_fn`). These are presumably the before/after forms of single statements;
# confirm which lines are live.
@register_model
def tresnet_l ( pretrained = False , num_classes = 1000 , in_chans = 3 , * * kwargs ) :
default_cfg = default_cfgs [ ' tresnet_l ' ]
model = TResNet ( layers = [ 4 , 5 , 18 , 3 ] , num_classes = num_classes , in_chans = in_chans , width_factor = 1.2 )
model = TResNet (
layers = [ 4 , 5 , 18 , 3 ] , num_classes = num_classes , in_chans = in_chans , width_factor = 1.2 , * * kwargs )
model . default_cfg = default_cfg
if pretrained :
load_pretrained ( model , default_cfg , num_classes , in_chans , filter_fn = filter_fn )
load_pretrained ( model , default_cfg , num_classes , in_chans )
return model
# Registered factory for the `tresnet_xl` variant: builds a TResNet with layers
# [4, 5, 24, 3] and width_factor 1.3, attaches the matching `default_cfgs` entry as
# `model.default_cfg`, and calls `load_pretrained` when `pretrained` is true.
# Returns the model.
# NOTE(review): the `TResNet(...)` construction appears twice (without and with
# `**kwargs`). This is presumably the before/after form of one statement; confirm
# which line is live.
@register_model
def tresnet_xl ( pretrained = False , num_classes = 1000 , in_chans = 3 , * * kwargs ) :
default_cfg = default_cfgs [ ' tresnet_xl ' ]
model = TResNet ( layers = [ 4 , 5 , 24 , 3 ] , num_classes = num_classes , in_chans = in_chans , width_factor = 1.3 )
model = TResNet (
layers = [ 4 , 5 , 24 , 3 ] , num_classes = num_classes , in_chans = in_chans , width_factor = 1.3 , * * kwargs )
model . default_cfg = default_cfg
if pretrained :
load_pretrained ( model , default_cfg , num_classes , in_chans )
return model
@register_model
def tresnet_m_448(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Build the ``tresnet_m_448`` variant (layers ``[3, 4, 11, 3]``).

    Args:
        pretrained: When true, weights are loaded through ``load_pretrained``.
        num_classes: Forwarded to ``TResNet`` and ``load_pretrained``.
        in_chans: Forwarded to ``TResNet`` and ``load_pretrained``.
        **kwargs: Extra keyword arguments forwarded to ``TResNet``.

    Returns:
        The constructed model with ``default_cfg`` attached.
    """
    cfg = default_cfgs[' tresnet_m_448 ']
    net = TResNet(layers=[3, 4, 11, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
    net.default_cfg = cfg
    if not pretrained:
        return net
    load_pretrained(net, cfg, num_classes, in_chans)
    return net
@register_model
def tresnet_l_448(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Build the ``tresnet_l_448`` variant (layers ``[4, 5, 18, 3]``, width factor 1.2).

    Args:
        pretrained: When true, weights are loaded through ``load_pretrained``.
        num_classes: Forwarded to ``TResNet`` and ``load_pretrained``.
        in_chans: Forwarded to ``TResNet`` and ``load_pretrained``.
        **kwargs: Extra keyword arguments forwarded to ``TResNet``.

    Returns:
        The constructed model with ``default_cfg`` attached.
    """
    cfg = default_cfgs[' tresnet_l_448 ']
    net = TResNet(
        layers=[4, 5, 18, 3],
        num_classes=num_classes,
        in_chans=in_chans,
        width_factor=1.2,
        **kwargs,
    )
    net.default_cfg = cfg
    if not pretrained:
        return net
    load_pretrained(net, cfg, num_classes, in_chans)
    return net
# Registered factory for the `tresnet_xl_448` variant: builds a TResNet with layers
# [4, 5, 24, 3] and width_factor 1.3 (extra `**kwargs` are forwarded), attaches the
# matching `default_cfgs` entry as `model.default_cfg`, and calls `load_pretrained`
# when `pretrained` is true. Returns the model.
# NOTE(review): the `load_pretrained(...)` call appears twice (with and without
# `filter_fn`). This is presumably the before/after form of one statement; confirm
# which line is live.
@register_model
def tresnet_xl_448 ( pretrained = False , num_classes = 1000 , in_chans = 3 , * * kwargs ) :
default_cfg = default_cfgs [ ' tresnet_xl_448 ' ]
model = TResNet (
layers = [ 4 , 5 , 24 , 3 ] , num_classes = num_classes , in_chans = in_chans , width_factor = 1.3 , * * kwargs )
model . default_cfg = default_cfg
if pretrained :
load_pretrained ( model , default_cfg , num_classes , in_chans , filter_fn = filter_fn )
load_pretrained ( model , default_cfg , num_classes , in_chans )
return model