More models supporting feature extraction, xception, gluon_xception, inception_v3, inception_v4, pnasnet, nasnet, dla. Fix DLA unused projection params.

pull/175/head
Ross Wightman 4 years ago
parent 298fba09ac
commit 9eba134d79

@ -37,8 +37,7 @@ def test_model_forward(model_name, batch_size):
@pytest.mark.timeout(120) @pytest.mark.timeout(120)
# DLA models have an issue TBD, add them to exclusions @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + ['dla*']))
@pytest.mark.parametrize('batch_size', [2]) @pytest.mark.parametrize('batch_size', [2])
def test_model_backward(model_name, batch_size): def test_model_backward(model_name, batch_size):
"""Run a single forward pass with each model""" """Run a single forward pass with each model"""

@ -12,7 +12,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained from .helpers import build_model_with_cfg
from .layers import SelectAdaptivePool2d from .layers import SelectAdaptivePool2d
from .registry import register_model from .registry import register_model
@ -212,10 +212,19 @@ class DlaTree(nn.Module):
root_dim = 2 * out_channels root_dim = 2 * out_channels
if level_root: if level_root:
root_dim += in_channels root_dim += in_channels
self.downsample = nn.MaxPool2d(stride, stride=stride) if stride > 1 else nn.Identity()
self.project = nn.Identity()
cargs = dict(dilation=dilation, cardinality=cardinality, base_width=base_width) cargs = dict(dilation=dilation, cardinality=cardinality, base_width=base_width)
if levels == 1: if levels == 1:
self.tree1 = block(in_channels, out_channels, stride, **cargs) self.tree1 = block(in_channels, out_channels, stride, **cargs)
self.tree2 = block(out_channels, out_channels, 1, **cargs) self.tree2 = block(out_channels, out_channels, 1, **cargs)
if in_channels != out_channels:
# NOTE the official impl/weights have project layers in levels > 1 case that are never
# used, I've moved the project layer here to avoid wasted params but old checkpoints will
# need strict=False while loading.
self.project = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
nn.BatchNorm2d(out_channels))
else: else:
cargs.update(dict(root_kernel_size=root_kernel_size, root_residual=root_residual)) cargs.update(dict(root_kernel_size=root_kernel_size, root_residual=root_residual))
self.tree1 = DlaTree( self.tree1 = DlaTree(
@ -226,22 +235,12 @@ class DlaTree(nn.Module):
self.root = DlaRoot(root_dim, out_channels, root_kernel_size, root_residual) self.root = DlaRoot(root_dim, out_channels, root_kernel_size, root_residual)
self.level_root = level_root self.level_root = level_root
self.root_dim = root_dim self.root_dim = root_dim
self.downsample = nn.MaxPool2d(stride, stride=stride) if stride > 1 else None
self.project = None
if in_channels != out_channels:
self.project = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
nn.BatchNorm2d(out_channels)
)
self.levels = levels self.levels = levels
def forward(self, x, residual=None, children=None): def forward(self, x, residual=None, children=None):
children = [] if children is None else children children = [] if children is None else children
# FIXME the way downsample / project are used here and residual is passed to next level up bottom = self.downsample(x)
# the tree, the residual is overridden and some project weights are thus never used and residual = self.project(bottom)
# have no gradients. This appears to be an issue with the original model / weights.
bottom = self.downsample(x) if self.downsample is not None else x
residual = self.project(bottom) if self.project is not None else bottom
if self.level_root: if self.level_root:
children.append(bottom) children.append(bottom)
x1 = self.tree1(x, residual) x1 = self.tree1(x, residual)
@ -255,8 +254,8 @@ class DlaTree(nn.Module):
class DLA(nn.Module): class DLA(nn.Module):
def __init__(self, levels, channels, num_classes=1000, in_chans=3, cardinality=1, base_width=64, def __init__(self, levels, channels, output_stride=32, num_classes=1000, in_chans=3,
block=DlaBottle2neck, residual_root=False, linear_root=False, cardinality=1, base_width=64, block=DlaBottle2neck, residual_root=False,
drop_rate=0.0, global_pool='avg'): drop_rate=0.0, global_pool='avg'):
super(DLA, self).__init__() super(DLA, self).__init__()
self.channels = channels self.channels = channels
@ -264,6 +263,7 @@ class DLA(nn.Module):
self.cardinality = cardinality self.cardinality = cardinality
self.base_width = base_width self.base_width = base_width
self.drop_rate = drop_rate self.drop_rate = drop_rate
assert output_stride == 32 # FIXME support dilation
self.base_layer = nn.Sequential( self.base_layer = nn.Sequential(
nn.Conv2d(in_chans, channels[0], kernel_size=7, stride=1, padding=3, bias=False), nn.Conv2d(in_chans, channels[0], kernel_size=7, stride=1, padding=3, bias=False),
@ -276,6 +276,14 @@ class DLA(nn.Module):
self.level3 = DlaTree(levels[3], block, channels[2], channels[3], 2, level_root=True, **cargs) self.level3 = DlaTree(levels[3], block, channels[2], channels[3], 2, level_root=True, **cargs)
self.level4 = DlaTree(levels[4], block, channels[3], channels[4], 2, level_root=True, **cargs) self.level4 = DlaTree(levels[4], block, channels[3], channels[4], 2, level_root=True, **cargs)
self.level5 = DlaTree(levels[5], block, channels[4], channels[5], 2, level_root=True, **cargs) self.level5 = DlaTree(levels[5], block, channels[4], channels[5], 2, level_root=True, **cargs)
self.feature_info = [
dict(num_chs=channels[0], reduction=1, module='level0'), # rare to have a meaningful stride 1 level
dict(num_chs=channels[1], reduction=2, module='level1'),
dict(num_chs=channels[2], reduction=4, module='level2'),
dict(num_chs=channels[3], reduction=8, module='level3'),
dict(num_chs=channels[4], reduction=16, module='level4'),
dict(num_chs=channels[5], reduction=32, module='level5'),
]
self.num_features = channels[-1] self.num_features = channels[-1]
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
@ -331,142 +339,103 @@ class DLA(nn.Module):
return x.flatten(1) return x.flatten(1)
def _create_dla(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
DLA, variant, pretrained, default_cfg=default_cfgs[variant],
pretrained_strict=False, feature_cfg=dict(out_indices=(1, 2, 3, 4, 5)), **kwargs)
@register_model @register_model
def dla60_res2net(pretrained=None, num_classes=1000, in_chans=3, **kwargs): def dla60_res2net(pretrained=False, **kwargs):
default_cfg = default_cfgs['dla60_res2net'] model_kwargs = dict(
model = DLA(levels=(1, 1, 1, 2, 3, 1), channels=(16, 32, 128, 256, 512, 1024), levels=(1, 1, 1, 2, 3, 1), channels=(16, 32, 128, 256, 512, 1024),
block=DlaBottle2neck, cardinality=1, base_width=28, block=DlaBottle2neck, cardinality=1, base_width=28, **kwargs)
num_classes=num_classes, in_chans=in_chans, **kwargs) return _create_dla('dla60_res2net', pretrained, **model_kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model @register_model
def dla60_res2next(pretrained=None, num_classes=1000, in_chans=3, **kwargs): def dla60_res2next(pretrained=False,**kwargs):
default_cfg = default_cfgs['dla60_res2next'] model_kwargs = dict(
model = DLA(levels=(1, 1, 1, 2, 3, 1), channels=(16, 32, 128, 256, 512, 1024), levels=(1, 1, 1, 2, 3, 1), channels=(16, 32, 128, 256, 512, 1024),
block=DlaBottle2neck, cardinality=8, base_width=4, block=DlaBottle2neck, cardinality=8, base_width=4, **kwargs)
num_classes=num_classes, in_chans=in_chans, **kwargs) return _create_dla('dla60_res2next', pretrained, **model_kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model @register_model
def dla34(pretrained=None, num_classes=1000, in_chans=3, **kwargs): # DLA-34 def dla34(pretrained=False, **kwargs): # DLA-34
default_cfg = default_cfgs['dla34'] model_kwargs = dict(
model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512], block=DlaBasic, levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 128, 256, 512],
num_classes=num_classes, in_chans=in_chans, **kwargs) block=DlaBasic, **kwargs)
model.default_cfg = default_cfg return _create_dla('dla34', pretrained, **model_kwargs)
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model @register_model
def dla46_c(pretrained=None, num_classes=1000, in_chans=3, **kwargs): # DLA-46-C def dla46_c(pretrained=False, **kwargs): # DLA-46-C
default_cfg = default_cfgs['dla46_c'] model_kwargs = dict(
model = DLA(levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 64, 128, 256], levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 64, 128, 256],
block=DlaBottleneck, num_classes=num_classes, in_chans=in_chans, **kwargs) block=DlaBottleneck, **kwargs)
model.default_cfg = default_cfg return _create_dla('dla46_c', pretrained, **model_kwargs)
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model @register_model
def dla46x_c(pretrained=None, num_classes=1000, in_chans=3, **kwargs): # DLA-X-46-C def dla46x_c(pretrained=False, **kwargs): # DLA-X-46-C
default_cfg = default_cfgs['dla46x_c'] model_kwargs = dict(
model = DLA(levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 64, 128, 256], levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 64, 128, 256],
block=DlaBottleneck, cardinality=32, base_width=4, block=DlaBottleneck, cardinality=32, base_width=4, **kwargs)
num_classes=num_classes, in_chans=in_chans, **kwargs) return _create_dla('dla46x_c', pretrained, **model_kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model @register_model
def dla60x_c(pretrained=None, num_classes=1000, in_chans=3, **kwargs): # DLA-X-60-C def dla60x_c(pretrained=False, **kwargs): # DLA-X-60-C
default_cfg = default_cfgs['dla60x_c'] model_kwargs = dict(
model = DLA([1, 1, 1, 2, 3, 1], [16, 32, 64, 64, 128, 256], levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 64, 64, 128, 256],
block=DlaBottleneck, cardinality=32, base_width=4, block=DlaBottleneck, cardinality=32, base_width=4, **kwargs)
num_classes=num_classes, in_chans=in_chans, **kwargs) return _create_dla('dla60x_c', pretrained, **model_kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model @register_model
def dla60(pretrained=None, num_classes=1000, in_chans=3, **kwargs): # DLA-60 def dla60(pretrained=False, **kwargs): # DLA-60
default_cfg = default_cfgs['dla60'] model_kwargs = dict(
model = DLA([1, 1, 1, 2, 3, 1], [16, 32, 128, 256, 512, 1024], levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 128, 256, 512, 1024],
block=DlaBottleneck, num_classes=num_classes, in_chans=in_chans, **kwargs) block=DlaBottleneck, **kwargs)
model.default_cfg = default_cfg return _create_dla('dla60', pretrained, **model_kwargs)
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model @register_model
def dla60x(pretrained=None, num_classes=1000, in_chans=3, **kwargs): # DLA-X-60 def dla60x(pretrained=False, **kwargs): # DLA-X-60
default_cfg = default_cfgs['dla60x'] model_kwargs = dict(
model = DLA([1, 1, 1, 2, 3, 1], [16, 32, 128, 256, 512, 1024], levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 128, 256, 512, 1024],
block=DlaBottleneck, cardinality=32, base_width=4, block=DlaBottleneck, cardinality=32, base_width=4, **kwargs)
num_classes=num_classes, in_chans=in_chans, **kwargs) return _create_dla('dla60x', pretrained, **model_kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model @register_model
def dla102(pretrained=None, num_classes=1000, in_chans=3, **kwargs): # DLA-102 def dla102(pretrained=False, **kwargs): # DLA-102
default_cfg = default_cfgs['dla102'] model_kwargs = dict(
model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024], levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
block=DlaBottleneck, residual_root=True, block=DlaBottleneck, residual_root=True, **kwargs)
num_classes=num_classes, in_chans=in_chans, **kwargs) return _create_dla('dla102', pretrained, **model_kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model @register_model
def dla102x(pretrained=None, num_classes=1000, in_chans=3, **kwargs): # DLA-X-102 def dla102x(pretrained=False, **kwargs): # DLA-X-102
default_cfg = default_cfgs['dla102x'] model_kwargs = dict(
model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024], levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
block=DlaBottleneck, cardinality=32, base_width=4, residual_root=True, block=DlaBottleneck, cardinality=32, base_width=4, residual_root=True, **kwargs)
num_classes=num_classes, in_chans=in_chans, **kwargs) return _create_dla('dla102x', pretrained, **model_kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model @register_model
def dla102x2(pretrained=None, num_classes=1000, in_chans=3, **kwargs): # DLA-X-102 64 def dla102x2(pretrained=False, **kwargs): # DLA-X-102 64
default_cfg = default_cfgs['dla102x2'] model_kwargs = dict(
model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024], levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
block=DlaBottleneck, cardinality=64, base_width=4, residual_root=True, block=DlaBottleneck, cardinality=64, base_width=4, residual_root=True, **kwargs)
num_classes=num_classes, in_chans=in_chans, **kwargs) return _create_dla('dla102x2', pretrained, **model_kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model @register_model
def dla169(pretrained=None, num_classes=1000, in_chans=3, **kwargs): # DLA-169 def dla169(pretrained=False, **kwargs): # DLA-169
default_cfg = default_cfgs['dla169'] model_kwargs = dict(
model = DLA([1, 1, 2, 3, 5, 1], [16, 32, 128, 256, 512, 1024], levels=[1, 1, 2, 3, 5, 1], channels=[16, 32, 128, 256, 512, 1024],
block=DlaBottleneck, residual_root=True, block=DlaBottleneck, residual_root=True, **kwargs)
num_classes=num_classes, in_chans=in_chans, **kwargs) return _create_dla('dla169', pretrained, **model_kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model

@ -12,7 +12,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained from .helpers import build_model_with_cfg
from .layers import SelectAdaptivePool2d, get_padding from .layers import SelectAdaptivePool2d, get_padding
from .registry import register_model from .registry import register_model
@ -141,13 +141,15 @@ class Xception65(nn.Module):
# Entry flow # Entry flow
self.conv1 = nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1, bias=False) self.conv1 = nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = norm_layer(num_features=32, **norm_kwargs) self.bn1 = norm_layer(num_features=32, **norm_kwargs)
self.relu = nn.ReLU(inplace=True) self.act1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False) self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = norm_layer(num_features=64) self.bn2 = norm_layer(num_features=64)
self.act2 = nn.ReLU(inplace=True)
self.block1 = Block( self.block1 = Block(
64, 128, stride=2, start_with_relu=False, norm_layer=norm_layer, norm_kwargs=norm_kwargs) 64, 128, stride=2, start_with_relu=False, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
self.block1_act = nn.ReLU(inplace=True)
self.block2 = Block( self.block2 = Block(
128, 256, stride=2, start_with_relu=False, norm_layer=norm_layer, norm_kwargs=norm_kwargs) 128, 256, stride=2, start_with_relu=False, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
self.block3 = Block( self.block3 = Block(
@ -162,22 +164,34 @@ class Xception65(nn.Module):
self.block20 = Block( self.block20 = Block(
728, (728, 1024, 1024), stride=exit_block20_stride, dilation=exit_block_dilations[0], 728, (728, 1024, 1024), stride=exit_block20_stride, dilation=exit_block_dilations[0],
norm_layer=norm_layer, norm_kwargs=norm_kwargs) norm_layer=norm_layer, norm_kwargs=norm_kwargs)
self.block20_act = nn.ReLU(inplace=True)
self.conv3 = SeparableConv2d( self.conv3 = SeparableConv2d(
1024, 1536, 3, stride=1, dilation=exit_block_dilations[1], 1024, 1536, 3, stride=1, dilation=exit_block_dilations[1],
norm_layer=norm_layer, norm_kwargs=norm_kwargs) norm_layer=norm_layer, norm_kwargs=norm_kwargs)
self.bn3 = norm_layer(num_features=1536, **norm_kwargs) self.bn3 = norm_layer(num_features=1536, **norm_kwargs)
self.act3 = nn.ReLU(inplace=True)
self.conv4 = SeparableConv2d( self.conv4 = SeparableConv2d(
1536, 1536, 3, stride=1, dilation=exit_block_dilations[1], 1536, 1536, 3, stride=1, dilation=exit_block_dilations[1],
norm_layer=norm_layer, norm_kwargs=norm_kwargs) norm_layer=norm_layer, norm_kwargs=norm_kwargs)
self.bn4 = norm_layer(num_features=1536, **norm_kwargs) self.bn4 = norm_layer(num_features=1536, **norm_kwargs)
self.act4 = nn.ReLU(inplace=True)
self.num_features = 2048 self.num_features = 2048
self.conv5 = SeparableConv2d( self.conv5 = SeparableConv2d(
1536, self.num_features, 3, stride=1, dilation=exit_block_dilations[1], 1536, self.num_features, 3, stride=1, dilation=exit_block_dilations[1],
norm_layer=norm_layer, norm_kwargs=norm_kwargs) norm_layer=norm_layer, norm_kwargs=norm_kwargs)
self.bn5 = norm_layer(num_features=self.num_features, **norm_kwargs) self.bn5 = norm_layer(num_features=self.num_features, **norm_kwargs)
self.act5 = nn.ReLU(inplace=True)
self.feature_info = [
dict(num_chs=64, reduction=2, module='act2'),
dict(num_chs=128, reduction=4, module='block1_act'),
dict(num_chs=256, reduction=8, module='block3.rep.act1'),
dict(num_chs=728, reduction=16, module='block20.rep.act1'),
dict(num_chs=2048, reduction=32, module='act5'),
]
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
@ -193,15 +207,14 @@ class Xception65(nn.Module):
# Entry flow # Entry flow
x = self.conv1(x) x = self.conv1(x)
x = self.bn1(x) x = self.bn1(x)
x = self.relu(x) x = self.act1(x)
x = self.conv2(x) x = self.conv2(x)
x = self.bn2(x) x = self.bn2(x)
x = self.relu(x) x = self.act2(x)
x = self.block1(x) x = self.block1(x)
# add relu here x = self.block1_act(x)
x = self.relu(x)
# c1 = x # c1 = x
x = self.block2(x) x = self.block2(x)
# c2 = x # c2 = x
@ -213,18 +226,18 @@ class Xception65(nn.Module):
# Exit flow # Exit flow
x = self.block20(x) x = self.block20(x)
x = self.relu(x) x = self.block20_act(x)
x = self.conv3(x) x = self.conv3(x)
x = self.bn3(x) x = self.bn3(x)
x = self.relu(x) x = self.act3(x)
x = self.conv4(x) x = self.conv4(x)
x = self.bn4(x) x = self.bn4(x)
x = self.relu(x) x = self.act4(x)
x = self.conv5(x) x = self.conv5(x)
x = self.bn5(x) x = self.bn5(x)
x = self.relu(x) x = self.act5(x)
return x return x
def forward(self, x): def forward(self, x):
@ -236,13 +249,14 @@ class Xception65(nn.Module):
return x return x
def _create_gluon_xception(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
Xception65, variant, pretrained, default_cfg=default_cfgs[variant],
feature_cfg=dict(use_hooks=True), **kwargs)
@register_model @register_model
def gluon_xception65(pretrained=False, num_classes=1000, in_chans=3, **kwargs): def gluon_xception65(pretrained=False, **kwargs):
""" Modified Aligned Xception-65 """ Modified Aligned Xception-65
""" """
default_cfg = default_cfgs['gluon_xception65'] return _create_gluon_xception('gluon_xception65', pretrained, **kwargs)
model = Xception65(num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model

@ -8,7 +8,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.utils.model_zoo as model_zoo import torch.utils.model_zoo as model_zoo
from .features import FeatureNet from .features import FeatureNet, FeatureHookNet
from .layers import Conv2dSame from .layers import Conv2dSame
@ -207,6 +207,7 @@ def build_model_with_cfg(
default_cfg: dict, default_cfg: dict,
model_cfg: dict = None, model_cfg: dict = None,
feature_cfg: dict = None, feature_cfg: dict = None,
pretrained_strict: bool = True,
pretrained_filter_fn: Callable = None, pretrained_filter_fn: Callable = None,
**kwargs): **kwargs):
pruned = kwargs.pop('pruned', False) pruned = kwargs.pop('pruned', False)
@ -230,10 +231,18 @@ def build_model_with_cfg(
model, model,
num_classes=kwargs.get('num_classes', 0), num_classes=kwargs.get('num_classes', 0),
in_chans=kwargs.get('in_chans', 3), in_chans=kwargs.get('in_chans', 3),
filter_fn=pretrained_filter_fn) filter_fn=pretrained_filter_fn, strict=pretrained_strict)
if features: if features:
feature_cls = feature_cfg.pop('feature_cls', FeatureNet) feature_cls = feature_cfg.pop('feature_cls', FeatureNet)
if isinstance(feature_cls, str):
feature_cls = feature_cls.lower()
if feature_cls == 'hook' or feature_cls == 'featurehooknet':
feature_cls = FeatureHookNet
else:
assert False, f'Unknown feature class {feature_cls}'
if feature_cls == FeatureHookNet and hasattr(model, 'reset_classifier'):
model.reset_classifier(0)
model = feature_cls(model, **feature_cfg) model = feature_cls(model, **feature_cfg)
return model return model

@ -735,6 +735,7 @@ class HighResolutionNet(nn.Module):
def _create_hrnet(variant, pretrained, **model_kwargs): def _create_hrnet(variant, pretrained, **model_kwargs):
assert not model_kwargs.pop('features_only', False) # feature extraction not figured out yet
return build_model_with_cfg( return build_model_with_cfg(
HighResolutionNet, variant, pretrained, default_cfg=default_cfgs[variant], HighResolutionNet, variant, pretrained, default_cfg=default_cfgs[variant],
model_cfg=cfg_cls[variant], **model_kwargs) model_cfg=cfg_cls[variant], **model_kwargs)

@ -3,7 +3,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import load_pretrained from .helpers import build_model_with_cfg
from .registry import register_model from .registry import register_model
from .layers import trunc_normal_, SelectAdaptivePool2d from .layers import trunc_normal_, SelectAdaptivePool2d
@ -44,231 +44,6 @@ default_cfgs = {
} }
class InceptionV3Aux(nn.Module):
"""InceptionV3 with AuxLogits
"""
def __init__(self, inception_blocks=None, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg'):
super(InceptionV3Aux, self).__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
if inception_blocks is None:
inception_blocks = [
BasicConv2d, InceptionA, InceptionB, InceptionC,
InceptionD, InceptionE, InceptionAux
]
assert len(inception_blocks) == 7
conv_block = inception_blocks[0]
inception_a = inception_blocks[1]
inception_b = inception_blocks[2]
inception_c = inception_blocks[3]
inception_d = inception_blocks[4]
inception_e = inception_blocks[5]
inception_aux = inception_blocks[6]
self.Conv2d_1a_3x3 = conv_block(in_chans, 32, kernel_size=3, stride=2)
self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
self.Mixed_5b = inception_a(192, pool_features=32)
self.Mixed_5c = inception_a(256, pool_features=64)
self.Mixed_5d = inception_a(288, pool_features=64)
self.Mixed_6a = inception_b(288)
self.Mixed_6b = inception_c(768, channels_7x7=128)
self.Mixed_6c = inception_c(768, channels_7x7=160)
self.Mixed_6d = inception_c(768, channels_7x7=160)
self.Mixed_6e = inception_c(768, channels_7x7=192)
self.AuxLogits = inception_aux(768, num_classes)
self.Mixed_7a = inception_d(768)
self.Mixed_7b = inception_e(1280)
self.Mixed_7c = inception_e(2048)
self.num_features = 2048
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
stddev = m.stddev if hasattr(m, 'stddev') else 0.1
trunc_normal_(m.weight, std=stddev)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward_features(self, x):
# N x 3 x 299 x 299
x = self.Conv2d_1a_3x3(x)
# N x 32 x 149 x 149
x = self.Conv2d_2a_3x3(x)
# N x 32 x 147 x 147
x = self.Conv2d_2b_3x3(x)
# N x 64 x 147 x 147
x = F.max_pool2d(x, kernel_size=3, stride=2)
# N x 64 x 73 x 73
x = self.Conv2d_3b_1x1(x)
# N x 80 x 73 x 73
x = self.Conv2d_4a_3x3(x)
# N x 192 x 71 x 71
x = F.max_pool2d(x, kernel_size=3, stride=2)
# N x 192 x 35 x 35
x = self.Mixed_5b(x)
# N x 256 x 35 x 35
x = self.Mixed_5c(x)
# N x 288 x 35 x 35
x = self.Mixed_5d(x)
# N x 288 x 35 x 35
x = self.Mixed_6a(x)
# N x 768 x 17 x 17
x = self.Mixed_6b(x)
# N x 768 x 17 x 17
x = self.Mixed_6c(x)
# N x 768 x 17 x 17
x = self.Mixed_6d(x)
# N x 768 x 17 x 17
x = self.Mixed_6e(x)
# N x 768 x 17 x 17
aux = self.AuxLogits(x) if self.training else None
# N x 768 x 17 x 17
x = self.Mixed_7a(x)
# N x 1280 x 8 x 8
x = self.Mixed_7b(x)
# N x 2048 x 8 x 8
x = self.Mixed_7c(x)
# N x 2048 x 8 x 8
return x, aux
def get_classifier(self):
return self.fc
def reset_classifier(self, num_classes, global_pool='avg'):
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.num_classes = num_classes
if self.num_classes > 0:
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
else:
self.fc = nn.Identity()
def forward(self, x):
x, aux = self.forward_features(x)
x = self.global_pool(x).flatten(1)
if self.drop_rate > 0:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.fc(x)
return x, aux
class InceptionV3(nn.Module):
"""Inception-V3 with no AuxLogits
FIXME two class defs are redundant, but less screwing around with torchsript fussyness and inconsistent returns
"""
def __init__(self, inception_blocks=None, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg'):
super(InceptionV3, self).__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
if inception_blocks is None:
inception_blocks = [
BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE]
assert len(inception_blocks) >= 6
conv_block = inception_blocks[0]
inception_a = inception_blocks[1]
inception_b = inception_blocks[2]
inception_c = inception_blocks[3]
inception_d = inception_blocks[4]
inception_e = inception_blocks[5]
self.Conv2d_1a_3x3 = conv_block(in_chans, 32, kernel_size=3, stride=2)
self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
self.Mixed_5b = inception_a(192, pool_features=32)
self.Mixed_5c = inception_a(256, pool_features=64)
self.Mixed_5d = inception_a(288, pool_features=64)
self.Mixed_6a = inception_b(288)
self.Mixed_6b = inception_c(768, channels_7x7=128)
self.Mixed_6c = inception_c(768, channels_7x7=160)
self.Mixed_6d = inception_c(768, channels_7x7=160)
self.Mixed_6e = inception_c(768, channels_7x7=192)
self.Mixed_7a = inception_d(768)
self.Mixed_7b = inception_e(1280)
self.Mixed_7c = inception_e(2048)
self.num_features = 2048
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.fc = nn.Linear(2048, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
stddev = m.stddev if hasattr(m, 'stddev') else 0.1
trunc_normal_(m.weight, std=stddev)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward_features(self, x):
# N x 3 x 299 x 299
x = self.Conv2d_1a_3x3(x)
# N x 32 x 149 x 149
x = self.Conv2d_2a_3x3(x)
# N x 32 x 147 x 147
x = self.Conv2d_2b_3x3(x)
# N x 64 x 147 x 147
x = F.max_pool2d(x, kernel_size=3, stride=2)
# N x 64 x 73 x 73
x = self.Conv2d_3b_1x1(x)
# N x 80 x 73 x 73
x = self.Conv2d_4a_3x3(x)
# N x 192 x 71 x 71
x = F.max_pool2d(x, kernel_size=3, stride=2)
# N x 192 x 35 x 35
x = self.Mixed_5b(x)
# N x 256 x 35 x 35
x = self.Mixed_5c(x)
# N x 288 x 35 x 35
x = self.Mixed_5d(x)
# N x 288 x 35 x 35
x = self.Mixed_6a(x)
# N x 768 x 17 x 17
x = self.Mixed_6b(x)
# N x 768 x 17 x 17
x = self.Mixed_6c(x)
# N x 768 x 17 x 17
x = self.Mixed_6d(x)
# N x 768 x 17 x 17
x = self.Mixed_6e(x)
# N x 768 x 17 x 17
x = self.Mixed_7a(x)
# N x 1280 x 8 x 8
x = self.Mixed_7b(x)
# N x 2048 x 8 x 8
x = self.Mixed_7c(x)
# N x 2048 x 8 x 8
return x
def get_classifier(self):
return self.fc
def reset_classifier(self, num_classes, global_pool='avg'):
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.num_classes = num_classes
if self.num_classes > 0:
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
else:
self.fc = nn.Identity()
def forward(self, x):
x = self.forward_features(x)
x = self.global_pool(x).flatten(1)
if self.drop_rate > 0:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.fc(x)
return x
class InceptionA(nn.Module): class InceptionA(nn.Module):
def __init__(self, in_channels, pool_features, conv_block=None): def __init__(self, in_channels, pool_features, conv_block=None):
@ -504,26 +279,163 @@ class BasicConv2d(nn.Module):
return F.relu(x, inplace=True) return F.relu(x, inplace=True)
class InceptionV3(nn.Module):
"""Inception-V3 with no AuxLogits
FIXME two class defs are redundant, but less screwing around with torchsript fussyness and inconsistent returns
"""
def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg', aux_logits=False):
super(InceptionV3, self).__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
self.aux_logits = aux_logits
self.Conv2d_1a_3x3 = BasicConv2d(in_chans, 32, kernel_size=3, stride=2)
self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
self.Pool1 = nn.MaxPool2d(kernel_size=3, stride=2)
self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
self.Pool2 = nn.MaxPool2d(kernel_size=3, stride=2)
self.Mixed_5b = InceptionA(192, pool_features=32)
self.Mixed_5c = InceptionA(256, pool_features=64)
self.Mixed_5d = InceptionA(288, pool_features=64)
self.Mixed_6a = InceptionB(288)
self.Mixed_6b = InceptionC(768, channels_7x7=128)
self.Mixed_6c = InceptionC(768, channels_7x7=160)
self.Mixed_6d = InceptionC(768, channels_7x7=160)
self.Mixed_6e = InceptionC(768, channels_7x7=192)
if aux_logits:
self.AuxLogits = InceptionAux(768, num_classes)
else:
self.AuxLogits = None
self.Mixed_7a = InceptionD(768)
self.Mixed_7b = InceptionE(1280)
self.Mixed_7c = InceptionE(2048)
self.feature_info = [
dict(num_chs=64, reduction=2, module='Conv2d_2b_3x3'),
dict(num_chs=192, reduction=4, module='Conv2d_4a_3x3'),
dict(num_chs=288, reduction=8, module='Mixed_5d'),
dict(num_chs=768, reduction=16, module='Mixed_6e'),
dict(num_chs=2048, reduction=32, module='Mixed_7c'),
]
self.num_features = 2048
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.fc = nn.Linear(2048, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
stddev = m.stddev if hasattr(m, 'stddev') else 0.1
trunc_normal_(m.weight, std=stddev)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward_preaux(self, x):
# N x 3 x 299 x 299
x = self.Conv2d_1a_3x3(x)
# N x 32 x 149 x 149
x = self.Conv2d_2a_3x3(x)
# N x 32 x 147 x 147
x = self.Conv2d_2b_3x3(x)
# N x 64 x 147 x 147
x = self.Pool1(x)
# N x 64 x 73 x 73
x = self.Conv2d_3b_1x1(x)
# N x 80 x 73 x 73
x = self.Conv2d_4a_3x3(x)
# N x 192 x 71 x 71
x = self.Pool2(x)
# N x 192 x 35 x 35
x = self.Mixed_5b(x)
# N x 256 x 35 x 35
x = self.Mixed_5c(x)
# N x 288 x 35 x 35
x = self.Mixed_5d(x)
# N x 288 x 35 x 35
x = self.Mixed_6a(x)
# N x 768 x 17 x 17
x = self.Mixed_6b(x)
# N x 768 x 17 x 17
x = self.Mixed_6c(x)
# N x 768 x 17 x 17
x = self.Mixed_6d(x)
# N x 768 x 17 x 17
x = self.Mixed_6e(x)
# N x 768 x 17 x 17
return x
def forward_postaux(self, x):
x = self.Mixed_7a(x)
# N x 1280 x 8 x 8
x = self.Mixed_7b(x)
# N x 2048 x 8 x 8
x = self.Mixed_7c(x)
# N x 2048 x 8 x 8
return x
def forward_features(self, x):
x = self.forward_preaux(x)
x = self.forward_postaux(x)
return x
def get_classifier(self):
return self.fc
def reset_classifier(self, num_classes, global_pool='avg'):
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.num_classes = num_classes
if self.num_classes > 0:
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
else:
self.fc = nn.Identity()
def forward(self, x):
x = self.forward_features(x)
x = self.global_pool(x).flatten(1)
if self.drop_rate > 0:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.fc(x)
return x
class InceptionV3Aux(InceptionV3):
"""InceptionV3 with AuxLogits
"""
def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg', aux_logits=True):
super(InceptionV3Aux, self).__init__(
num_classes, in_chans, drop_rate, global_pool, aux_logits)
def forward_features(self, x):
x = self.forward_preaux(x)
aux = self.AuxLogits(x) if self.training else None
x = self.forward_postaux(x)
return x, aux
def forward(self, x):
x, aux = self.forward_features(x)
x = self.global_pool(x).flatten(1)
if self.drop_rate > 0:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.fc(x)
return x, aux
def _create_inception_v3(variant, pretrained=False, **kwargs): def _create_inception_v3(variant, pretrained=False, **kwargs):
assert not kwargs.pop('features_only', False)
default_cfg = default_cfgs[variant] default_cfg = default_cfgs[variant]
aux_logits = kwargs.pop('aux_logits', False) aux_logits = kwargs.pop('aux_logits', False)
if aux_logits: if aux_logits:
model_class = InceptionV3Aux assert not kwargs.pop('features_only', False)
model_cls = InceptionV3Aux
load_strict = default_cfg['has_aux'] load_strict = default_cfg['has_aux']
else: else:
model_class = InceptionV3 model_cls = InceptionV3
load_strict = not default_cfg['has_aux'] load_strict = not default_cfg['has_aux']
return build_model_with_cfg(
model = model_class(**kwargs) model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
model.default_cfg = default_cfg pretrained_strict=load_strict, **kwargs)
if pretrained:
load_pretrained(
model,
num_classes=kwargs.get('num_classes', 0),
in_chans=kwargs.get('in_chans', 3),
strict=load_strict)
return model
@register_model @register_model

@ -7,7 +7,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import load_pretrained from .helpers import build_model_with_cfg
from .layers import SelectAdaptivePool2d from .layers import SelectAdaptivePool2d
from .registry import register_model from .registry import register_model
@ -39,9 +39,9 @@ class BasicConv2d(nn.Module):
return x return x
class Mixed_3a(nn.Module): class Mixed3a(nn.Module):
def __init__(self): def __init__(self):
super(Mixed_3a, self).__init__() super(Mixed3a, self).__init__()
self.maxpool = nn.MaxPool2d(3, stride=2) self.maxpool = nn.MaxPool2d(3, stride=2)
self.conv = BasicConv2d(64, 96, kernel_size=3, stride=2) self.conv = BasicConv2d(64, 96, kernel_size=3, stride=2)
@ -52,9 +52,9 @@ class Mixed_3a(nn.Module):
return out return out
class Mixed_4a(nn.Module): class Mixed4a(nn.Module):
def __init__(self): def __init__(self):
super(Mixed_4a, self).__init__() super(Mixed4a, self).__init__()
self.branch0 = nn.Sequential( self.branch0 = nn.Sequential(
BasicConv2d(160, 64, kernel_size=1, stride=1), BasicConv2d(160, 64, kernel_size=1, stride=1),
@ -75,9 +75,9 @@ class Mixed_4a(nn.Module):
return out return out
class Mixed_5a(nn.Module): class Mixed5a(nn.Module):
def __init__(self): def __init__(self):
super(Mixed_5a, self).__init__() super(Mixed5a, self).__init__()
self.conv = BasicConv2d(192, 192, kernel_size=3, stride=2) self.conv = BasicConv2d(192, 192, kernel_size=3, stride=2)
self.maxpool = nn.MaxPool2d(3, stride=2) self.maxpool = nn.MaxPool2d(3, stride=2)
@ -88,9 +88,9 @@ class Mixed_5a(nn.Module):
return out return out
class Inception_A(nn.Module): class InceptionA(nn.Module):
def __init__(self): def __init__(self):
super(Inception_A, self).__init__() super(InceptionA, self).__init__()
self.branch0 = BasicConv2d(384, 96, kernel_size=1, stride=1) self.branch0 = BasicConv2d(384, 96, kernel_size=1, stride=1)
self.branch1 = nn.Sequential( self.branch1 = nn.Sequential(
@ -118,9 +118,9 @@ class Inception_A(nn.Module):
return out return out
class Reduction_A(nn.Module): class ReductionA(nn.Module):
def __init__(self): def __init__(self):
super(Reduction_A, self).__init__() super(ReductionA, self).__init__()
self.branch0 = BasicConv2d(384, 384, kernel_size=3, stride=2) self.branch0 = BasicConv2d(384, 384, kernel_size=3, stride=2)
self.branch1 = nn.Sequential( self.branch1 = nn.Sequential(
@ -139,9 +139,9 @@ class Reduction_A(nn.Module):
return out return out
class Inception_B(nn.Module): class InceptionB(nn.Module):
def __init__(self): def __init__(self):
super(Inception_B, self).__init__() super(InceptionB, self).__init__()
self.branch0 = BasicConv2d(1024, 384, kernel_size=1, stride=1) self.branch0 = BasicConv2d(1024, 384, kernel_size=1, stride=1)
self.branch1 = nn.Sequential( self.branch1 = nn.Sequential(
@ -172,9 +172,9 @@ class Inception_B(nn.Module):
return out return out
class Reduction_B(nn.Module): class ReductionB(nn.Module):
def __init__(self): def __init__(self):
super(Reduction_B, self).__init__() super(ReductionB, self).__init__()
self.branch0 = nn.Sequential( self.branch0 = nn.Sequential(
BasicConv2d(1024, 192, kernel_size=1, stride=1), BasicConv2d(1024, 192, kernel_size=1, stride=1),
@ -198,9 +198,9 @@ class Reduction_B(nn.Module):
return out return out
class Inception_C(nn.Module): class InceptionC(nn.Module):
def __init__(self): def __init__(self):
super(Inception_C, self).__init__() super(InceptionC, self).__init__()
self.branch0 = BasicConv2d(1536, 256, kernel_size=1, stride=1) self.branch0 = BasicConv2d(1536, 256, kernel_size=1, stride=1)
@ -241,8 +241,9 @@ class Inception_C(nn.Module):
class InceptionV4(nn.Module): class InceptionV4(nn.Module):
def __init__(self, num_classes=1001, in_chans=3, drop_rate=0., global_pool='avg'): def __init__(self, num_classes=1001, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg'):
super(InceptionV4, self).__init__() super(InceptionV4, self).__init__()
assert output_stride == 32
self.drop_rate = drop_rate self.drop_rate = drop_rate
self.num_classes = num_classes self.num_classes = num_classes
self.num_features = 1536 self.num_features = 1536
@ -251,26 +252,33 @@ class InceptionV4(nn.Module):
BasicConv2d(in_chans, 32, kernel_size=3, stride=2), BasicConv2d(in_chans, 32, kernel_size=3, stride=2),
BasicConv2d(32, 32, kernel_size=3, stride=1), BasicConv2d(32, 32, kernel_size=3, stride=1),
BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1), BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1),
Mixed_3a(), Mixed3a(),
Mixed_4a(), Mixed4a(),
Mixed_5a(), Mixed5a(),
Inception_A(), InceptionA(),
Inception_A(), InceptionA(),
Inception_A(), InceptionA(),
Inception_A(), InceptionA(),
Reduction_A(), # Mixed_6a ReductionA(), # Mixed6a
Inception_B(), InceptionB(),
Inception_B(), InceptionB(),
Inception_B(), InceptionB(),
Inception_B(), InceptionB(),
Inception_B(), InceptionB(),
Inception_B(), InceptionB(),
Inception_B(), InceptionB(),
Reduction_B(), # Mixed_7a ReductionB(), # Mixed7a
Inception_C(), InceptionC(),
Inception_C(), InceptionC(),
Inception_C(), InceptionC(),
) )
self.feature_info = [
dict(num_chs=64, reduction=2, module='features.2'),
dict(num_chs=160, reduction=4, module='features.3'),
dict(num_chs=384, reduction=8, module='features.9'),
dict(num_chs=1024, reduction=16, module='features.17'),
dict(num_chs=1536, reduction=32, module='features.21'),
]
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
@ -298,11 +306,12 @@ class InceptionV4(nn.Module):
return x return x
def _create_inception_v4(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
InceptionV4, variant, pretrained, default_cfg=default_cfgs[variant],
feature_cfg=dict(flatten_sequential=True), **kwargs)
@register_model @register_model
def inception_v4(pretrained=False, num_classes=1000, in_chans=3, **kwargs): def inception_v4(pretrained=False, **kwargs):
default_cfg = default_cfgs['inception_v4'] return _create_inception_v4('inception_v4', pretrained, **kwargs)
model = InceptionV4(num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model

@ -1,8 +1,11 @@
"""
"""
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .helpers import load_pretrained from .helpers import build_model_with_cfg
from .layers import SelectAdaptivePool2d, ConvBnAct, create_conv2d, create_pool2d from .layers import SelectAdaptivePool2d, ConvBnAct, create_conv2d, create_pool2d
from .registry import register_model from .registry import register_model
@ -484,8 +487,15 @@ class NASNetALarge(nn.Module):
self.cell_17 = NormalCell( self.cell_17 = NormalCell(
in_chs_left=24 * channels, out_chs_left=4 * channels, in_chs_left=24 * channels, out_chs_left=4 * channels,
in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type) in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type)
self.act = nn.ReLU(inplace=True) self.act = nn.ReLU(inplace=True)
self.feature_info = [
dict(num_chs=96, reduction=2, module='conv0'),
dict(num_chs=168, reduction=4, module='cell_stem_1.conv_1x1.act'),
dict(num_chs=1008, reduction=8, module='reduction_cell_0.conv_1x1.act'),
dict(num_chs=2016, reduction=16, module='reduction_cell_1.conv_1x1.act'),
dict(num_chs=4032, reduction=32, module='act'),
]
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
@ -503,11 +513,9 @@ class NASNetALarge(nn.Module):
def forward_features(self, x): def forward_features(self, x):
x_conv0 = self.conv0(x) x_conv0 = self.conv0(x)
#0
x_stem_0 = self.cell_stem_0(x_conv0) x_stem_0 = self.cell_stem_0(x_conv0)
x_stem_1 = self.cell_stem_1(x_conv0, x_stem_0) x_stem_1 = self.cell_stem_1(x_conv0, x_stem_0)
#1
x_cell_0 = self.cell_0(x_stem_1, x_stem_0) x_cell_0 = self.cell_0(x_stem_1, x_stem_0)
x_cell_1 = self.cell_1(x_cell_0, x_stem_1) x_cell_1 = self.cell_1(x_cell_0, x_stem_1)
@ -515,7 +523,6 @@ class NASNetALarge(nn.Module):
x_cell_3 = self.cell_3(x_cell_2, x_cell_1) x_cell_3 = self.cell_3(x_cell_2, x_cell_1)
x_cell_4 = self.cell_4(x_cell_3, x_cell_2) x_cell_4 = self.cell_4(x_cell_3, x_cell_2)
x_cell_5 = self.cell_5(x_cell_4, x_cell_3) x_cell_5 = self.cell_5(x_cell_4, x_cell_3)
#2
x_reduction_cell_0 = self.reduction_cell_0(x_cell_5, x_cell_4) x_reduction_cell_0 = self.reduction_cell_0(x_cell_5, x_cell_4)
x_cell_6 = self.cell_6(x_reduction_cell_0, x_cell_4) x_cell_6 = self.cell_6(x_reduction_cell_0, x_cell_4)
@ -524,7 +531,6 @@ class NASNetALarge(nn.Module):
x_cell_9 = self.cell_9(x_cell_8, x_cell_7) x_cell_9 = self.cell_9(x_cell_8, x_cell_7)
x_cell_10 = self.cell_10(x_cell_9, x_cell_8) x_cell_10 = self.cell_10(x_cell_9, x_cell_8)
x_cell_11 = self.cell_11(x_cell_10, x_cell_9) x_cell_11 = self.cell_11(x_cell_10, x_cell_9)
#3
x_reduction_cell_1 = self.reduction_cell_1(x_cell_11, x_cell_10) x_reduction_cell_1 = self.reduction_cell_1(x_cell_11, x_cell_10)
x_cell_12 = self.cell_12(x_reduction_cell_1, x_cell_10) x_cell_12 = self.cell_12(x_reduction_cell_1, x_cell_10)
@ -534,8 +540,6 @@ class NASNetALarge(nn.Module):
x_cell_16 = self.cell_16(x_cell_15, x_cell_14) x_cell_16 = self.cell_16(x_cell_15, x_cell_14)
x_cell_17 = self.cell_17(x_cell_16, x_cell_15) x_cell_17 = self.cell_17(x_cell_16, x_cell_15)
x = self.act(x_cell_17) x = self.act(x_cell_17)
#4
return x return x
def forward(self, x): def forward(self, x):
@ -547,14 +551,16 @@ class NASNetALarge(nn.Module):
return x return x
def _create_nasnet(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
NASNetALarge, variant, pretrained, default_cfg=default_cfgs[variant],
feature_cfg=dict(feature_cls='hook'), # not possible to re-write this model, must use FeatureHookNet
**kwargs)
@register_model @register_model
def nasnetalarge(pretrained=False, num_classes=1000, in_chans=3, **kwargs): def nasnetalarge(pretrained=False, **kwargs):
"""NASNet-A large model architecture. """NASNet-A large model architecture.
""" """
default_cfg = default_cfgs['nasnetalarge'] model_kwargs = dict(pad_type='same', **kwargs)
model = NASNetALarge(num_classes=num_classes, in_chans=in_chans, **kwargs) return _create_nasnet('nasnetalarge', pretrained, **model_kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model

@ -5,15 +5,13 @@
https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/pnasnet.py https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/pnasnet.py
""" """
from __future__ import print_function, division, absolute_import
from collections import OrderedDict from collections import OrderedDict
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .helpers import load_pretrained from .helpers import build_model_with_cfg
from .layers import SelectAdaptivePool2d, ConvBnAct, create_conv2d, create_pool2d from .layers import SelectAdaptivePool2d, ConvBnAct, create_conv2d, create_pool2d
from .registry import register_model from .registry import register_model
@ -147,35 +145,35 @@ class CellBase(nn.Module):
class CellStem0(CellBase): class CellStem0(CellBase):
def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, padding=''): def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''):
super(CellStem0, self).__init__() super(CellStem0, self).__init__()
self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=padding) self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=pad_type)
self.comb_iter_0_left = BranchSeparables( self.comb_iter_0_left = BranchSeparables(
in_chs_left, out_chs_left, kernel_size=5, stride=2, stem_cell=True, padding=padding) in_chs_left, out_chs_left, kernel_size=5, stride=2, stem_cell=True, padding=pad_type)
self.comb_iter_0_right = nn.Sequential(OrderedDict([ self.comb_iter_0_right = nn.Sequential(OrderedDict([
('max_pool', create_pool2d('max', 3, stride=2, padding=padding)), ('max_pool', create_pool2d('max', 3, stride=2, padding=pad_type)),
('conv', create_conv2d(in_chs_left, out_chs_left, kernel_size=1, padding=padding)), ('conv', create_conv2d(in_chs_left, out_chs_left, kernel_size=1, padding=pad_type)),
('bn', nn.BatchNorm2d(out_chs_left, eps=0.001)), ('bn', nn.BatchNorm2d(out_chs_left, eps=0.001)),
])) ]))
self.comb_iter_1_left = BranchSeparables( self.comb_iter_1_left = BranchSeparables(
out_chs_right, out_chs_right, kernel_size=7, stride=2, padding=padding) out_chs_right, out_chs_right, kernel_size=7, stride=2, padding=pad_type)
self.comb_iter_1_right = create_pool2d('max', 3, stride=2, padding=padding) self.comb_iter_1_right = create_pool2d('max', 3, stride=2, padding=pad_type)
self.comb_iter_2_left = BranchSeparables( self.comb_iter_2_left = BranchSeparables(
out_chs_right, out_chs_right, kernel_size=5, stride=2, padding=padding) out_chs_right, out_chs_right, kernel_size=5, stride=2, padding=pad_type)
self.comb_iter_2_right = BranchSeparables( self.comb_iter_2_right = BranchSeparables(
out_chs_right, out_chs_right, kernel_size=3, stride=2, padding=padding) out_chs_right, out_chs_right, kernel_size=3, stride=2, padding=pad_type)
self.comb_iter_3_left = BranchSeparables( self.comb_iter_3_left = BranchSeparables(
out_chs_right, out_chs_right, kernel_size=3, padding=padding) out_chs_right, out_chs_right, kernel_size=3, padding=pad_type)
self.comb_iter_3_right = create_pool2d('max', 3, stride=2, padding=padding) self.comb_iter_3_right = create_pool2d('max', 3, stride=2, padding=pad_type)
self.comb_iter_4_left = BranchSeparables( self.comb_iter_4_left = BranchSeparables(
in_chs_right, out_chs_right, kernel_size=3, stride=2, stem_cell=True, padding=padding) in_chs_right, out_chs_right, kernel_size=3, stride=2, stem_cell=True, padding=pad_type)
self.comb_iter_4_right = ActConvBn( self.comb_iter_4_right = ActConvBn(
out_chs_right, out_chs_right, kernel_size=1, stride=2, padding=padding) out_chs_right, out_chs_right, kernel_size=1, stride=2, padding=pad_type)
def forward(self, x_left): def forward(self, x_left):
x_right = self.conv_1x1(x_left) x_right = self.conv_1x1(x_left)
@ -185,12 +183,12 @@ class CellStem0(CellBase):
class Cell(CellBase): class Cell(CellBase):
def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, padding='', def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type='',
is_reduction=False, match_prev_layer_dims=False): is_reduction=False, match_prev_layer_dims=False):
super(Cell, self).__init__() super(Cell, self).__init__()
# If `is_reduction` is set to `True` stride 2 is used for # If `is_reduction` is set to `True` stride 2 is used for
# convolutional and pooling layers to reduce the spatial size of # convolution and pooling layers to reduce the spatial size of
# the output of a cell approximately by a factor of 2. # the output of a cell approximately by a factor of 2.
stride = 2 if is_reduction else 1 stride = 2 if is_reduction else 1
@ -199,32 +197,32 @@ class Cell(CellBase):
# of the left input of a cell approximately by a factor of 2. # of the left input of a cell approximately by a factor of 2.
self.match_prev_layer_dimensions = match_prev_layer_dims self.match_prev_layer_dimensions = match_prev_layer_dims
if match_prev_layer_dims: if match_prev_layer_dims:
self.conv_prev_1x1 = FactorizedReduction(in_chs_left, out_chs_left, padding=padding) self.conv_prev_1x1 = FactorizedReduction(in_chs_left, out_chs_left, padding=pad_type)
else: else:
self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, kernel_size=1, padding=padding) self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, kernel_size=1, padding=pad_type)
self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=padding) self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=pad_type)
self.comb_iter_0_left = BranchSeparables( self.comb_iter_0_left = BranchSeparables(
out_chs_left, out_chs_left, kernel_size=5, stride=stride, padding=padding) out_chs_left, out_chs_left, kernel_size=5, stride=stride, padding=pad_type)
self.comb_iter_0_right = create_pool2d('max', 3, stride=stride, padding=padding) self.comb_iter_0_right = create_pool2d('max', 3, stride=stride, padding=pad_type)
self.comb_iter_1_left = BranchSeparables( self.comb_iter_1_left = BranchSeparables(
out_chs_right, out_chs_right, kernel_size=7, stride=stride, padding=padding) out_chs_right, out_chs_right, kernel_size=7, stride=stride, padding=pad_type)
self.comb_iter_1_right = create_pool2d('max', 3, stride=stride, padding=padding) self.comb_iter_1_right = create_pool2d('max', 3, stride=stride, padding=pad_type)
self.comb_iter_2_left = BranchSeparables( self.comb_iter_2_left = BranchSeparables(
out_chs_right, out_chs_right, kernel_size=5, stride=stride, padding=padding) out_chs_right, out_chs_right, kernel_size=5, stride=stride, padding=pad_type)
self.comb_iter_2_right = BranchSeparables( self.comb_iter_2_right = BranchSeparables(
out_chs_right, out_chs_right, kernel_size=3, stride=stride, padding=padding) out_chs_right, out_chs_right, kernel_size=3, stride=stride, padding=pad_type)
self.comb_iter_3_left = BranchSeparables(out_chs_right, out_chs_right, kernel_size=3) self.comb_iter_3_left = BranchSeparables(out_chs_right, out_chs_right, kernel_size=3)
self.comb_iter_3_right = create_pool2d('max', 3, stride=stride, padding=padding) self.comb_iter_3_right = create_pool2d('max', 3, stride=stride, padding=pad_type)
self.comb_iter_4_left = BranchSeparables( self.comb_iter_4_left = BranchSeparables(
out_chs_left, out_chs_left, kernel_size=3, stride=stride, padding=padding) out_chs_left, out_chs_left, kernel_size=3, stride=stride, padding=pad_type)
if is_reduction: if is_reduction:
self.comb_iter_4_right = ActConvBn( self.comb_iter_4_right = ActConvBn(
out_chs_right, out_chs_right, kernel_size=1, stride=stride, padding=padding) out_chs_right, out_chs_right, kernel_size=1, stride=stride, padding=pad_type)
else: else:
self.comb_iter_4_right = None self.comb_iter_4_right = None
@ -236,7 +234,7 @@ class Cell(CellBase):
class PNASNet5Large(nn.Module): class PNASNet5Large(nn.Module):
def __init__(self, num_classes=1001, in_chans=3, output_stride=32, drop_rate=0.5, global_pool='avg', padding=''): def __init__(self, num_classes=1001, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg', pad_type=''):
super(PNASNet5Large, self).__init__() super(PNASNet5Large, self).__init__()
self.num_classes = num_classes self.num_classes = num_classes
self.drop_rate = drop_rate self.drop_rate = drop_rate
@ -248,43 +246,51 @@ class PNASNet5Large(nn.Module):
norm_kwargs=dict(eps=0.001, momentum=0.1), act_layer=None) norm_kwargs=dict(eps=0.001, momentum=0.1), act_layer=None)
self.cell_stem_0 = CellStem0( self.cell_stem_0 = CellStem0(
in_chs_left=96, out_chs_left=54, in_chs_right=96, out_chs_right=54, padding=padding) in_chs_left=96, out_chs_left=54, in_chs_right=96, out_chs_right=54, pad_type=pad_type)
self.cell_stem_1 = Cell( self.cell_stem_1 = Cell(
in_chs_left=96, out_chs_left=108, in_chs_right=270, out_chs_right=108, padding=padding, in_chs_left=96, out_chs_left=108, in_chs_right=270, out_chs_right=108, pad_type=pad_type,
match_prev_layer_dims=True, is_reduction=True) match_prev_layer_dims=True, is_reduction=True)
self.cell_0 = Cell( self.cell_0 = Cell(
in_chs_left=270, out_chs_left=216, in_chs_right=540, out_chs_right=216, padding=padding, in_chs_left=270, out_chs_left=216, in_chs_right=540, out_chs_right=216, pad_type=pad_type,
match_prev_layer_dims=True) match_prev_layer_dims=True)
self.cell_1 = Cell( self.cell_1 = Cell(
in_chs_left=540, out_chs_left=216, in_chs_right=1080, out_chs_right=216, padding=padding) in_chs_left=540, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type)
self.cell_2 = Cell( self.cell_2 = Cell(
in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, padding=padding) in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type)
self.cell_3 = Cell( self.cell_3 = Cell(
in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, padding=padding) in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type)
self.cell_4 = Cell( self.cell_4 = Cell(
in_chs_left=1080, out_chs_left=432, in_chs_right=1080, out_chs_right=432, padding=padding, in_chs_left=1080, out_chs_left=432, in_chs_right=1080, out_chs_right=432, pad_type=pad_type,
is_reduction=True) is_reduction=True)
self.cell_5 = Cell( self.cell_5 = Cell(
in_chs_left=1080, out_chs_left=432, in_chs_right=2160, out_chs_right=432, padding=padding, in_chs_left=1080, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type,
match_prev_layer_dims=True) match_prev_layer_dims=True)
self.cell_6 = Cell( self.cell_6 = Cell(
in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, padding=padding) in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type)
self.cell_7 = Cell( self.cell_7 = Cell(
in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, padding=padding) in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type)
self.cell_8 = Cell( self.cell_8 = Cell(
in_chs_left=2160, out_chs_left=864, in_chs_right=2160, out_chs_right=864, padding=padding, in_chs_left=2160, out_chs_left=864, in_chs_right=2160, out_chs_right=864, pad_type=pad_type,
is_reduction=True) is_reduction=True)
self.cell_9 = Cell( self.cell_9 = Cell(
in_chs_left=2160, out_chs_left=864, in_chs_right=4320, out_chs_right=864, padding=padding, in_chs_left=2160, out_chs_left=864, in_chs_right=4320, out_chs_right=864, pad_type=pad_type,
match_prev_layer_dims=True) match_prev_layer_dims=True)
self.cell_10 = Cell( self.cell_10 = Cell(
in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, padding=padding) in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, pad_type=pad_type)
self.cell_11 = Cell( self.cell_11 = Cell(
in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, padding=padding) in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, pad_type=pad_type)
self.relu = nn.ReLU() self.act = nn.ReLU()
self.feature_info = [
dict(num_chs=96, reduction=2, module='conv_0'),
dict(num_chs=270, reduction=4, module='cell_stem_1.conv_1x1.act'),
dict(num_chs=1080, reduction=8, module='cell_4.conv_1x1.act'),
dict(num_chs=2160, reduction=16, module='cell_8.conv_1x1.act'),
dict(num_chs=4320, reduction=32, module='act'),
]
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
@ -316,7 +322,7 @@ class PNASNet5Large(nn.Module):
x_cell_9 = self.cell_9(x_cell_7, x_cell_8) x_cell_9 = self.cell_9(x_cell_7, x_cell_8)
x_cell_10 = self.cell_10(x_cell_8, x_cell_9) x_cell_10 = self.cell_10(x_cell_8, x_cell_9)
x_cell_11 = self.cell_11(x_cell_9, x_cell_10) x_cell_11 = self.cell_11(x_cell_9, x_cell_10)
x = self.relu(x_cell_11) x = self.act(x_cell_11)
return x return x
def forward(self, x): def forward(self, x):
@ -328,16 +334,18 @@ class PNASNet5Large(nn.Module):
return x return x
def _create_pnasnet(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
PNASNet5Large, variant, pretrained, default_cfg=default_cfgs[variant],
feature_cfg=dict(feature_cls='hook'), # not possible to re-write this model, must use FeatureHookNet
**kwargs)
@register_model @register_model
def pnasnet5large(pretrained=False, num_classes=1000, in_chans=3, **kwargs): def pnasnet5large(pretrained=False, **kwargs):
r"""PNASNet-5 model architecture from the r"""PNASNet-5 model architecture from the
`"Progressive Neural Architecture Search" `"Progressive Neural Architecture Search"
<https://arxiv.org/abs/1712.00559>`_ paper. <https://arxiv.org/abs/1712.00559>`_ paper.
""" """
default_cfg = default_cfgs['pnasnet5large'] model_kwargs = dict(pad_type='same', **kwargs)
model = PNASNet5Large(num_classes=num_classes, in_chans=in_chans, padding='same', **kwargs) return _create_pnasnet('pnasnet5large', pretrained, **model_kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model

@ -154,6 +154,13 @@ class Xception(nn.Module):
self.conv4 = SeparableConv2d(1536, self.num_features, 3, 1, 1) self.conv4 = SeparableConv2d(1536, self.num_features, 3, 1, 1)
self.bn4 = nn.BatchNorm2d(self.num_features) self.bn4 = nn.BatchNorm2d(self.num_features)
self.act4 = nn.ReLU(inplace=True) self.act4 = nn.ReLU(inplace=True)
self.feature_info = [
dict(num_chs=64, reduction=2, module='act2'),
dict(num_chs=128, reduction=4, module='block2.rep.0'),
dict(num_chs=256, reduction=8, module='block3.rep.0'),
dict(num_chs=728, reduction=16, module='block12.rep.0'),
dict(num_chs=2048, reduction=32, module='act4'),
]
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
@ -221,7 +228,7 @@ class Xception(nn.Module):
def _xception(variant, pretrained=False, **kwargs): def _xception(variant, pretrained=False, **kwargs):
return build_model_with_cfg( return build_model_with_cfg(
Xception, variant, pretrained, default_cfg=default_cfgs[variant], Xception, variant, pretrained, default_cfg=default_cfgs[variant],
feature_cfg=dict(), **kwargs) feature_cfg=dict(use_hooks=True), **kwargs)
@register_model @register_model

Loading…
Cancel
Save