@ -29,9 +29,9 @@ def _cfg(url='', **kwargs):
default_cfgs = {
' tresnet_m ' : _cfg (
url = ' https:// miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/tresnet_m_1k_miil_83_1 .pth' ) ,
url = ' https:// github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_m_1k_miil_83_1-d236afcb .pth' ) ,
' tresnet_m_miil_in21k ' : _cfg (
url = ' https:// miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/tresnet_m_miil_in21k .pth' , num_classes = 11221 ) ,
url = ' https:// github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_m_miil_in21k-901b6ed4 .pth' , num_classes = 11221 ) ,
' tresnet_l ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_81_5-235b486c.pth ' ) ,
' tresnet_xl ' : _cfg (
@ -44,7 +44,10 @@ default_cfgs = {
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_448-940d0cd1.pth ' ) ,
' tresnet_xl_448 ' : _cfg (
input_size = ( 3 , 448 , 448 ) , pool_size = ( 14 , 14 ) ,
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_xl_448-8c1815de.pth ' )
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_xl_448-8c1815de.pth ' ) ,
' tresnet_v2_l ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_v2_83_9-f36e4445.pth ' ) ,
}
@ -99,7 +102,7 @@ class BasicBlock(nn.Module):
if self . se is not None :
out = self . se ( out )
out += shortcut
out = out + shortcut
out = self . relu ( out )
return out
@ -153,7 +156,16 @@ class Bottleneck(nn.Module):
class TResNet ( nn . Module ) :
def __init__ ( self , layers , in_chans = 3 , num_classes = 1000 , width_factor = 1.0 , global_pool = ' fast ' , drop_rate = 0. ) :
def __init__ (
self ,
layers ,
in_chans = 3 ,
num_classes = 1000 ,
width_factor = 1.0 ,
v2 = False ,
global_pool = ' fast ' ,
drop_rate = 0. ,
) :
self . num_classes = num_classes
self . drop_rate = drop_rate
super ( TResNet , self ) . __init__ ( )
@ -163,15 +175,19 @@ class TResNet(nn.Module):
# TResnet stages
self . inplanes = int ( 64 * width_factor )
self . planes = int ( 64 * width_factor )
if v2 :
self . inplanes = self . inplanes / / 8 * 8
self . planes = self . planes / / 8 * 8
conv1 = conv2d_iabn ( in_chans * 16 , self . planes , stride = 1 , kernel_size = 3 )
layer1 = self . _make_layer (
BasicBlock , self . planes , layers [ 0 ] , stride = 1 , use_se = True , aa_layer = aa_layer ) # 56x56
B ottleneck if v2 else B asicBlock, self . planes , layers [ 0 ] , stride = 1 , use_se = True , aa_layer = aa_layer )
layer2 = self . _make_layer (
B asicBlock, self . planes * 2 , layers [ 1 ] , stride = 2 , use_se = True , aa_layer = aa_layer ) # 28x28
B ottleneck if v2 else B asicBlock, self . planes * 2 , layers [ 1 ] , stride = 2 , use_se = True , aa_layer = aa_layer )
layer3 = self . _make_layer (
Bottleneck , self . planes * 4 , layers [ 2 ] , stride = 2 , use_se = True , aa_layer = aa_layer ) # 14x14
Bottleneck , self . planes * 4 , layers [ 2 ] , stride = 2 , use_se = True , aa_layer = aa_layer )
layer4 = self . _make_layer (
Bottleneck , self . planes * 8 , layers [ 3 ] , stride = 2 , use_se = False , aa_layer = aa_layer ) # 7x7
Bottleneck , self . planes * 8 , layers [ 3 ] , stride = 2 , use_se = False , aa_layer = aa_layer )
# body
self . body = nn . Sequential ( OrderedDict ( [
@ -285,6 +301,12 @@ def tresnet_l(pretrained=False, **kwargs):
return _create_tresnet ( ' tresnet_l ' , pretrained = pretrained , * * model_kwargs )
@register_model
def tresnet_v2_l ( pretrained = False , * * kwargs ) :
model_kwargs = dict ( layers = [ 3 , 4 , 23 , 3 ] , width_factor = 1.0 , v2 = True , * * kwargs )
return _create_tresnet ( ' tresnet_v2_l ' , pretrained = pretrained , * * model_kwargs )
@register_model
def tresnet_xl ( pretrained = False , * * kwargs ) :
model_kwargs = dict ( layers = [ 4 , 5 , 24 , 3 ] , width_factor = 1.3 , * * kwargs )