Cleanup and refactoring of feature extraction code; add tests, fix tests; non-hook feature extraction now working with torchscript

pull/175/head
Ross Wightman 4 years ago
parent 6eec3fb4a4
commit 4e61c6a12d

@@ -106,3 +106,26 @@ def test_model_forward_torchscript(model_name, batch_size):
     assert outputs.shape[0] == batch_size
     assert not torch.isnan(outputs).any(), 'Output included NaNs'
+
+
+EXCLUDE_FEAT_FILTERS = [
+    'hrnet*', '*pruned*',  # hopefully fix at some point
+    'legacy*',  # not going to bother
+]
+
+
+@pytest.mark.timeout(120)
+@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FEAT_FILTERS))
+@pytest.mark.parametrize('batch_size', [1])
+def test_model_forward_features(model_name, batch_size):
+    """Run a single forward pass with each model in feature extraction mode"""
+    model = create_model(model_name, pretrained=False, features_only=True)
+    model.eval()
+    expected_channels = model.feature_info.channels()
+    assert len(expected_channels) >= 4  # all models here should have at least 4 feature levels by default, some 5 or 6
+    input_size = (3, 128, 128)  # jit compile is already a bit slow and we've tested normal res already...
+    outputs = model(torch.randn((batch_size, *input_size)))
+    assert len(expected_channels) == len(outputs)
+    for e, o in zip(expected_channels, outputs):
+        assert e == o.shape[1]
+        assert o.shape[0] == batch_size
+        assert not torch.isnan(o).any()

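Usage sketch (not part of this diff): what the new feature test exercises end to end, plus the torchscript claim in the commit message. It assumes a registered model such as 'resnet18' with the default five feature levels.

import torch
from timm import create_model

# Build the feature extraction variant of a model; the default (non-hook) wrapper is FeatureListNet.
model = create_model('resnet18', pretrained=False, features_only=True)
model.eval()

# Non-hook feature extraction can now be scripted.
scripted = torch.jit.script(model)

with torch.no_grad():
    feats = scripted(torch.randn(1, 3, 224, 224))

# Channels reported by feature_info line up with the returned feature maps.
for ch, f in zip(model.feature_info.channels(), feats):
    print(ch, tuple(f.shape))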
@@ -75,8 +75,14 @@ class FeatureInfo:
 
 
 class FeatureHooks:
+    """ Feature Hook Helper
+
+    This module helps with the setup and extraction of hooks for extracting features from
+    internal nodes in a model by node name. This works quite well in eager Python but needs
+    redesign for torchscript.
+    """
 
-    def __init__(self, hooks, named_modules, out_as_dict=False, out_map=None, default_hook_type='forward'):
+    def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'):
         # setup feature hooks
         modules = {k: v for k, v in named_modules}
         for i, h in enumerate(hooks):
@@ -92,7 +98,6 @@ class FeatureHooks:
             else:
                 assert False, "Unsupported hook type"
         self._feature_outputs = defaultdict(OrderedDict)
-        self.out_as_dict = out_as_dict
 
     def _collect_output_hook(self, hook_id, *args):
         x = args[-1]  # tensor we want is last argument, output for fwd, input for fwd_pre
@@ -100,11 +105,8 @@ class FeatureHooks:
             x = x[0]  # unwrap input tuple
         self._feature_outputs[x.device][hook_id] = x
 
-    def get_output(self, device) -> List[torch.tensor]:  # FIXME deal with diff return types for torchscript?
-        if self.out_as_dict:
-            output = self._feature_outputs[device]
-        else:
-            output = list(self._feature_outputs[device].values())
+    def get_output(self, device) -> Dict[str, torch.tensor]:
+        output = self._feature_outputs[device]
         self._feature_outputs[device] = OrderedDict()  # clear after reading
         return output
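Usage sketch (not part of this diff): driving FeatureHooks directly. The tiny module and hook names below are made up; `get_output` now always returns an ordered dict keyed by module name (or by `out_map` entries when given) and clears its buffer after reading.

import torch
import torch.nn as nn
from timm.models.features import FeatureHooks  # module path as of this refactor

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(3, 8, 3, stride=2, padding=1)
        self.block1 = nn.Conv2d(8, 16, 3, stride=2, padding=1)

    def forward(self, x):
        return self.block1(self.stem(x))

net = Tiny()
hooks = [dict(module='stem', hook_type='forward'), dict(module='block1', hook_type='forward')]
fh = FeatureHooks(hooks, net.named_modules())

x = torch.randn(1, 3, 32, 32)
net(x)                          # hooks fire during the normal forward pass
out = fh.get_output(x.device)   # OrderedDict: 'stem' -> ..., 'block1' -> ...
print([tuple(v.shape) for v in out.values()])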
@@ -123,83 +125,72 @@ def _module_list(module, flatten_sequential=False):
     return ml
 
 
-class LayerGetterHooks(nn.ModuleDict):
-    """ LayerGetterHooks
-    TODO
-    """
-
-    def __init__(self, model, feature_info, flatten_sequential=False, out_as_dict=False, out_map=None,
-                 default_hook_type='forward'):
-        modules = _module_list(model, flatten_sequential=flatten_sequential)
-        remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type for f in feature_info}
-        layers = OrderedDict()
-        hooks = []
-        for new_name, old_name, module in modules:
-            layers[new_name] = module
-            for fn, fm in module.named_modules(prefix=old_name):
-                if fn in remaining:
-                    hooks.append(dict(module=fn, hook_type=remaining[fn]))
-                    del remaining[fn]
-            if not remaining:
-                break
-        assert not remaining, f'Return layers ({remaining}) are not present in model'
-        super(LayerGetterHooks, self).__init__(layers)
-        self.hooks = FeatureHooks(hooks, model.named_modules(), out_as_dict=out_as_dict, out_map=out_map)
-
-    def forward(self, x) -> Dict[Any, torch.Tensor]:
-        for name, module in self.items():
-            x = module(x)
-        return self.hooks.get_output(x.device)
-
-
-class LayerGetterDict(nn.ModuleDict):
-    """
-    Module wrapper that returns intermediate layers from a model as a dictionary
-
-    Originally based on concepts from IntermediateLayerGetter at
-    https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py
-
-    It has a strong assumption that the modules have been registered into the model in the same
-    order as they are used. This means that one should **not** reuse the same nn.Module twice
-    in the forward if you want this to work.
-
-    Additionally, it is only able to query submodules that are directly assigned to the model
-    class (`model.feature1`) or at most one Sequential container deep (`model.features.1`, so
-    long as `features` is a sequential container assigned to the model).
-
-    All Sequential containers that are directly assigned to the original model will have their
-    modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
-
-    Arguments:
-        model (nn.Module): model on which we will extract the features
-        return_layers (Dict[name, new_name]): a dict containing the names
-            of the modules for which the activations will be returned as
-            the key of the dict, and the value of the dict is the name
-            of the returned activation (which the user can specify).
-        concat (bool): whether to concatenate intermediate features that are lists or tuples
-            vs select element [0]
-        flatten_sequential (bool): whether to flatten sequential modules assigned to model
-    """
-
-    def __init__(self, model, return_layers, concat=False, flatten_sequential=False):
-        self.return_layers = {}
-        self.concat = concat
-        modules = _module_list(model, flatten_sequential=flatten_sequential)
-        remaining = set(return_layers.keys())
-        layers = OrderedDict()
-        for new_name, old_name, module in modules:
-            layers[new_name] = module
-            if old_name in remaining:
-                self.return_layers[new_name] = return_layers[old_name]
-                remaining.remove(old_name)
-            if not remaining:
-                break
-        assert not remaining and len(self.return_layers) == len(return_layers), \
-            f'Return layers ({remaining}) are not present in model'
-        super(LayerGetterDict, self).__init__(layers)
-
-    def forward(self, x) -> Dict[Any, torch.Tensor]:
+def _get_feature_info(net, out_indices):
+    feature_info = getattr(net, 'feature_info')
+    if isinstance(feature_info, FeatureInfo):
+        return feature_info.from_other(out_indices)
+    elif isinstance(feature_info, (list, tuple)):
+        return FeatureInfo(net.feature_info, out_indices)
+    else:
+        assert False, "Provided feature_info is not valid"
+
+
+def _get_return_layers(feature_info, out_map):
+    module_names = feature_info.module_name()
+    return_layers = {}
+    for i, name in enumerate(module_names):
+        return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]
+    return return_layers
+
+
+class FeatureDictNet(nn.ModuleDict):
+    """ Feature extractor with OrderedDict return
+
+    Wrap a model and extract features as specified by the out indices, the network is
+    partially re-built from contained modules.
+
+    There is a strong assumption that the modules have been registered into the model in the same
+    order as they are used. There should be no reuse of the same nn.Module more than once, including
+    trivial modules like `self.relu = nn.ReLU`.
+
+    Only submodules that are directly assigned to the model class (`model.feature1`) or at most
+    one Sequential container deep (`model.features.1`, with flatten_sequential=True) can be captured.
+
+    All Sequential containers that are directly assigned to the original model will have their
+    modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
+
+    Arguments:
+        model (nn.Module): model from which we will extract the features
+        out_indices (tuple[int]): model output indices to extract features for
+        out_map (sequence): list or tuple specifying desired return id for each out index,
+            otherwise str(index) is used
+        feature_concat (bool): whether to concatenate intermediate features that are lists or tuples
+            vs select element [0]
+        flatten_sequential (bool): whether to flatten sequential modules assigned to model
+    """
+    def __init__(
+            self, model,
+            out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
+        super(FeatureDictNet, self).__init__()
+        self.feature_info = _get_feature_info(model, out_indices)
+        self.concat = feature_concat
+        self.return_layers = {}
+        return_layers = _get_return_layers(self.feature_info, out_map)
+        modules = _module_list(model, flatten_sequential=flatten_sequential)
+        remaining = set(return_layers.keys())
+        layers = OrderedDict()
+        for new_name, old_name, module in modules:
+            layers[new_name] = module
+            if old_name in remaining:
+                # return id has to be consistently str type for torchscript
+                self.return_layers[new_name] = str(return_layers[old_name])
+                remaining.remove(old_name)
+            if not remaining:
+                break
+        assert not remaining and len(self.return_layers) == len(return_layers), \
+            f'Return layers ({remaining}) are not present in model'
+        self.update(layers)
+
+    def _collect(self, x) -> (Dict[str, torch.Tensor]):
         out = OrderedDict()
         for name, module in self.items():
             x = module(x)
@@ -213,131 +204,74 @@ class LayerGetterDict(nn.ModuleDict):
                 out[out_id] = x
         return out
 
-
-class LayerGetterList(nn.Sequential):
-    """
-    Module wrapper that returns intermediate layers from a model as a list
-
-    Originally based on concepts from IntermediateLayerGetter at
-    https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py
-
-    It has a strong assumption that the modules have been registered into the model in the same
-    order as they are used. This means that one should **not** reuse the same nn.Module twice
-    in the forward if you want this to work.
-
-    Additionally, it is only able to query submodules that are directly assigned to the model
-    class (`model.feature1`) or at most one Sequential container deep (`model.features.1`) so
-    long as `features` is a sequential container assigned to the model and flatten_sequential=True.
-
-    All Sequential containers that are directly assigned to the original model will have their
-    modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
-
-    Arguments:
-        model (nn.Module): model on which we will extract the features
-        return_layers (Dict[name, new_name]): a dict containing the names
-            of the modules for which the activations will be returned as
-            the key of the dict, and the value of the dict is the name
-            of the returned activation (which the user can specify).
-        concat (bool): whether to concatenate intermediate features that are lists or tuples
-            vs select element [0]
-        flatten_sequential (bool): whether to flatten sequential modules assigned to model
-    """
-
-    def __init__(self, model, return_layers, concat=False, flatten_sequential=False):
-        super(LayerGetterList, self).__init__()
-        self.return_layers = {}
-        self.concat = concat
-        modules = _module_list(model, flatten_sequential=flatten_sequential)
-        remaining = set(return_layers.keys())
-        for new_name, orig_name, module in modules:
-            self.add_module(new_name, module)
-            if orig_name in remaining:
-                self.return_layers[new_name] = return_layers[orig_name]
-                remaining.remove(orig_name)
-            if not remaining:
-                break
-        assert not remaining and len(self.return_layers) == len(return_layers), \
-            f'Return layers ({remaining}) are not present in model'
-
-    def forward(self, x) -> List[torch.Tensor]:
-        out = []
-        for name, module in self.named_children():
-            x = module(x)
-            if name in self.return_layers:
-                if isinstance(x, (tuple, list)):
-                    # If model tap is a tuple or list, concat or select first element
-                    # FIXME this may need to be more generic / flexible for some nets
-                    out.append(torch.cat(x, 1) if self.concat else x[0])
-                else:
-                    out.append(x)
-        return out
-
-
-def _resolve_feature_info(net, out_indices, feature_info=None):
-    if feature_info is None:
-        feature_info = getattr(net, 'feature_info')
-    if isinstance(feature_info, FeatureInfo):
-        return feature_info.from_other(out_indices)
-    elif isinstance(feature_info, (list, tuple)):
-        return FeatureInfo(net.feature_info, out_indices)
-    else:
-        assert False, "Provided feature_info is not valid"
-
-
-def _get_return_layers(feature_info, out_map):
-    module_names = feature_info.module_name()
-    return_layers = {}
-    for i, name in enumerate(module_names):
-        return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]
-    return return_layers
-
-
-class FeatureNet(nn.Module):
-    """ FeatureNet
-
-    Wrap a model and extract features as specified by the out indices, the network
-    is partially re-built from contained modules using the LayerGetters.
-
-    Please read the docstrings of the LayerGetter classes, they will not work on all models.
-    """
-    def __init__(
-            self, net,
-            out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, use_hooks=False,
-            feature_info=None, feature_concat=False, flatten_sequential=False):
-        super(FeatureNet, self).__init__()
-        self.feature_info = _resolve_feature_info(net, out_indices, feature_info)
-        if use_hooks:
-            self.body = LayerGetterHooks(net, self.feature_info, out_as_dict=out_as_dict, out_map=out_map)
-        else:
-            return_layers = _get_return_layers(self.feature_info, out_map)
-            lg_args = dict(return_layers=return_layers, concat=feature_concat, flatten_sequential=flatten_sequential)
-            self.body = LayerGetterDict(net, **lg_args) if out_as_dict else LayerGetterList(net, **lg_args)
-
-    def forward(self, x):
-        output = self.body(x)
-        return output
-
-
-class FeatureHookNet(nn.Module):
-    """ FeatureHookNet
-
-    Wrap a model and extract features specified by the out indices.
-
-    Features are extracted via hooks without modifying the underlying network in any way. If only
-    part of the model is used it is up to the caller to remove unneeded layers as this wrapper
-    does not rewrite and remove unused top-level modules like FeatureNet with LayerGetter.
-    """
-    def __init__(
-            self, net,
-            out_indices=(0, 1, 2, 3, 4), out_as_dict=False, out_map=None,
-            feature_info=None, feature_concat=False):
-        super(FeatureHookNet, self).__init__()
-        self.feature_info = _resolve_feature_info(net, out_indices, feature_info)
-        self.body = net
-        self.hooks = FeatureHooks(
-            self.feature_info, self.body.named_modules(), out_as_dict=out_as_dict, out_map=out_map)
-
-    def forward(self, x):
-        self.body(x)
-        return self.hooks.get_output(x.device)
+    def forward(self, x) -> Dict[str, torch.Tensor]:
+        return self._collect(x)
+
+
+class FeatureListNet(FeatureDictNet):
+    """ Feature extractor with list return
+
+    See docstring for FeatureDictNet above, this class exists only to appease Torchscript typing constraints.
+    In eager Python we could have returned List[Tensor] vs Dict[id, Tensor] based on a member bool.
+    """
+    def __init__(
+            self, model,
+            out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
+        super(FeatureListNet, self).__init__(
+            model, out_indices=out_indices, out_map=out_map, feature_concat=feature_concat,
+            flatten_sequential=flatten_sequential)
+
+    def forward(self, x) -> (List[torch.Tensor]):
+        return list(self._collect(x).values())
+
+
+class FeatureHookNet(nn.ModuleDict):
+    """ FeatureHookNet
+
+    Wrap a model and extract features specified by the out indices using forward/forward-pre hooks.
+
+    If `no_rewrite` is True, features are extracted via hooks without modifying the underlying
+    network in any way.
+
+    If `no_rewrite` is False, the model will be re-written as in the
+    FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one.
+
+    FIXME this does not currently work with Torchscript, see FeatureHooks class
+    """
+    def __init__(
+            self, model,
+            out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, no_rewrite=False,
+            feature_concat=False, flatten_sequential=False, default_hook_type='forward'):
+        super(FeatureHookNet, self).__init__()
+        assert not torch.jit.is_scripting()
+        self.feature_info = _get_feature_info(model, out_indices)
+        self.out_as_dict = out_as_dict
+        layers = OrderedDict()
+        hooks = []
+        if no_rewrite:
+            assert not flatten_sequential
+            if hasattr(model, 'reset_classifier'):  # make sure classifier is removed?
+                model.reset_classifier(0)
+            layers['body'] = model
+            hooks.extend(self.feature_info)
+        else:
+            modules = _module_list(model, flatten_sequential=flatten_sequential)
+            remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
+                         for f in self.feature_info}
+            for new_name, old_name, module in modules:
+                layers[new_name] = module
+                for fn, fm in module.named_modules(prefix=old_name):
+                    if fn in remaining:
+                        hooks.append(dict(module=fn, hook_type=remaining[fn]))
+                        del remaining[fn]
+                if not remaining:
+                    break
+            assert not remaining, f'Return layers ({remaining}) are not present in model'
+        self.update(layers)
+        self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)
+
+    def forward(self, x):
+        for name, module in self.items():
+            x = module(x)
+        out = self.hooks.get_output(x.device)
+        return out if self.out_as_dict else list(out.values())

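Usage sketch (not part of this diff): the three wrappers side by side, assuming a model like timm's resnet18 that exposes a `feature_info` attribute and registers its stages as flat top-level modules.

import torch
from timm import create_model
from timm.models.features import FeatureDictNet, FeatureListNet, FeatureHookNet

x = torch.randn(1, 3, 224, 224)

# Dict return, keys are str(out_index) unless an out_map is provided.
dict_net = FeatureDictNet(create_model('resnet18'), out_indices=(1, 2, 3, 4))
print(list(dict_net(x).keys()))

# List return, the torchscript-friendly default used by features_only=True.
list_net = FeatureListNet(create_model('resnet18'), out_indices=(1, 2, 3, 4))
print([f.shape[1] for f in list_net(x)])

# Hook based, leaves the wrapped network untouched when no_rewrite=True.
hook_net = FeatureHookNet(create_model('resnet18'), out_indices=(1, 2, 3, 4), no_rewrite=True)
print(len(hook_net(x)))  # list by default, dict when out_as_dict=True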
@@ -252,7 +252,7 @@ class Xception65(nn.Module):
 def _create_gluon_xception(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
         Xception65, variant, pretrained, default_cfg=default_cfgs[variant],
-        feature_cfg=dict(use_hooks=True), **kwargs)
+        feature_cfg=dict(feature_cls='hook'), **kwargs)
 
 
 @register_model

@@ -8,7 +8,7 @@ import torch
 import torch.nn as nn
 import torch.utils.model_zoo as model_zoo
 
-from .features import FeatureNet, FeatureHookNet
+from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
 from .layers import Conv2dSame
@@ -234,15 +234,15 @@ def build_model_with_cfg(
             filter_fn=pretrained_filter_fn, strict=pretrained_strict)
 
     if features:
-        feature_cls = feature_cfg.pop('feature_cls', FeatureNet)
+        feature_cls = FeatureListNet
+        if 'feature_cls' in feature_cfg:
+            feature_cls = feature_cfg.pop('feature_cls')
         if isinstance(feature_cls, str):
             feature_cls = feature_cls.lower()
-            if feature_cls == 'hook' or feature_cls == 'featurehooknet':
+            if 'hook' in feature_cls:
                 feature_cls = FeatureHookNet
             else:
                 assert False, f'Unknown feature class {feature_cls}'
-        if feature_cls == FeatureHookNet and hasattr(model, 'reset_classifier'):
-            model.reset_classifier(0)
         model = feature_cls(model, **feature_cfg)
 
     return model

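Usage sketch (not part of this diff): the wrapper class a user ends up with is decided by each model's `feature_cfg`. resnet18 takes the default FeatureListNet re-build path, while the aligned Xception variants touched in this commit request `feature_cls='hook'`.

import torch
from timm import create_model

rebuilt = create_model('resnet18', features_only=True)    # default -> FeatureListNet
hooked = create_model('xception65', features_only=True)   # feature_cfg=dict(feature_cls='hook') -> FeatureHookNet

x = torch.randn(1, 3, 224, 224)
print(type(rebuilt).__name__, [f.shape[1] for f in rebuilt(x)])
print(type(hooked).__name__, [f.shape[1] for f in hooked(x)])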
@@ -211,7 +211,8 @@ class MobileNetV3Features(nn.Module):
             return features
         else:
             self.blocks(x)
-            return self.feature_hooks.get_output(x.device)
+            out = self.feature_hooks.get_output(x.device)
+            return list(out.values())
 
 
 def _create_mnv3(model_kwargs, variant, pretrained=False):
@@ -220,6 +221,7 @@ def _create_mnv3(model_kwargs, variant, pretrained=False):
         model_kwargs.pop('num_classes', 0)
         model_kwargs.pop('num_features', 0)
         model_kwargs.pop('head_conv', None)
+        model_kwargs.pop('head_bias', None)
         model_cls = MobileNetV3Features
     else:
         load_strict = True

@@ -554,7 +554,7 @@ class NASNetALarge(nn.Module):
 def _create_nasnet(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
         NASNetALarge, variant, pretrained, default_cfg=default_cfgs[variant],
-        feature_cfg=dict(feature_cls='hook'),  # not possible to re-write this model, must use FeatureHookNet
+        feature_cfg=dict(feature_cls='hook', no_rewrite=True),  # not possible to re-write this model
         **kwargs)

@@ -337,7 +337,7 @@ class PNASNet5Large(nn.Module):
 def _create_pnasnet(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
         PNASNet5Large, variant, pretrained, default_cfg=default_cfgs[variant],
-        feature_cfg=dict(feature_cls='hook'),  # not possible to re-write this model, must use FeatureHookNet
+        feature_cfg=dict(feature_cls='hook', no_rewrite=True),  # not possible to re-write this model
        **kwargs)

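Usage sketch (not part of this diff): with `no_rewrite=True` the wrapper does not fold the network's modules into itself; the original model is kept whole under a single 'body' entry (classifier removed) and features come purely from hooks. Assumes the registered name 'nasnetalarge'.

from timm import create_model

model = create_model('nasnetalarge', pretrained=False, features_only=True)
print(type(model).__name__)                        # FeatureHookNet
print(list(dict(model.named_children()).keys()))   # ['body']: the unmodified network
print(model.feature_info.channels())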
@@ -74,6 +74,29 @@ class SequentialList(nn.Sequential):
         return x
 
 
+class SelectSeq(nn.Module):
+    def __init__(self, mode='index', index=0):
+        super(SelectSeq, self).__init__()
+        self.mode = mode
+        self.index = index
+
+    @torch.jit._overload_method  # noqa: F811
+    def forward(self, x):
+        # type: (List[torch.Tensor]) -> (torch.Tensor)
+        pass
+
+    @torch.jit._overload_method  # noqa: F811
+    def forward(self, x):
+        # type: (Tuple[torch.Tensor]) -> (torch.Tensor)
+        pass
+
+    def forward(self, x) -> torch.Tensor:
+        if self.mode == 'index':
+            return x[self.index]
+        else:
+            return torch.cat(x, dim=1)
+
+
 def conv_bn(in_chs, out_chs, k=3, stride=1, padding=None, dilation=1):
     if padding is None:
         padding = ((stride - 1) + dilation * (k - 1)) // 2
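Usage sketch (not part of this diff): `SelectSeq` turns the List/Tuple of tensors produced by `SequentialList` back into a single Tensor inside an nn.Module, with `torch.jit._overload_method` declarations so callers can script against either input type. A minimal, made-up head using it:

import torch
import torch.nn as nn
from typing import List
from timm.models.selecsls import SelectSeq  # defined in selecsls.py as of this commit

class Head(nn.Module):
    def __init__(self):
        super().__init__()
        self.from_seq = SelectSeq()  # mode='index', index=0 selects the first tensor
        self.conv = nn.Conv2d(8, 4, 1)

    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
        return self.conv(self.from_seq(x))

head = torch.jit.script(Head())
print(head([torch.randn(1, 8, 7, 7), torch.randn(1, 8, 7, 7)]).shape)  # torch.Size([1, 4, 7, 7])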
@@ -137,8 +160,10 @@ class SelecSLS(nn.Module):
         self.stem = conv_bn(in_chans, 32, stride=2)
         self.features = SequentialList(*[cfg['block'](*block_args) for block_args in cfg['features']])
+        self.from_seq = SelectSeq()  # from List[tensor] -> Tensor in module compatible way
         self.head = nn.Sequential(*[conv_bn(*conv_args) for conv_args in cfg['head']])
         self.num_features = cfg['num_features']
+        self.feature_info = cfg['feature_info']
 
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
@@ -165,7 +190,7 @@ class SelecSLS(nn.Module):
     def forward_features(self, x):
         x = self.stem(x)
         x = self.features(x)
-        x = self.head(x[0])
+        x = self.head(self.from_seq(x))
         return x
 
     def forward(self, x):
@@ -297,6 +322,7 @@ def _create_selecsls(variant, pretrained, model_kwargs):
         ])
     else:
         raise ValueError('Invalid net configuration ' + variant + ' !!!')
+    cfg['feature_info'] = feature_info
 
     # this model can do 6 feature levels by default, unlike most others, leave as 0-4 to avoid surprises?
     return build_model_with_cfg(

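Usage sketch (not part of this diff): with `feature_info` now stored in the config, SelecSLS variants work with `features_only`; 'selecsls42b' is one of the registered variants. Per the comment above, only the usual five of its six possible levels are exposed by default.

from timm import create_model

model = create_model('selecsls42b', pretrained=False, features_only=True)
print(model.feature_info.module_name())
print(model.feature_info.reduction())
print(model.feature_info.channels())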
@@ -160,6 +160,9 @@ class Bottleneck(nn.Module):
                 conv2d_iabn(planes, planes, kernel_size=3, stride=1, act_layer=act_layer, act_param=1e-3),
                 aa_layer(channels=planes, filt_size=3, stride=2))
 
+        reduce_layer_planes = max(planes * self.expansion // 8, 64)
+        self.se = FastSEModule(planes, reduce_layer_planes) if use_se else None
+
         self.conv3 = conv2d_iabn(
             planes, planes * self.expansion, kernel_size=1, stride=1, act_layer="identity")
@@ -167,9 +170,6 @@ class Bottleneck(nn.Module):
         self.downsample = downsample
         self.stride = stride
 
-        reduce_layer_planes = max(planes * self.expansion // 8, 64)
-        self.se = FastSEModule(planes, reduce_layer_planes) if use_se else None
-
     def forward(self, x):
         if self.downsample is not None:
             residual = self.downsample(x)
@@ -225,8 +225,8 @@ class TResNet(nn.Module):
             dict(num_chs=self.planes, reduction=2, module=''),  # Not with S2D?
             dict(num_chs=self.planes, reduction=4, module='body.layer1'),
             dict(num_chs=self.planes * 2, reduction=8, module='body.layer2'),
-            dict(num_chs=self.planes * 4, reduction=16, module='body.layer3'),
-            dict(num_chs=self.planes * 8, reduction=32, module='body.layer4'),
+            dict(num_chs=self.planes * 4 * Bottleneck.expansion, reduction=16, module='body.layer3'),
+            dict(num_chs=self.planes * 8 * Bottleneck.expansion, reduction=32, module='body.layer4'),
         ]
 
         # head

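Quick check of the corrected channel counts above (not part of this diff), assuming `Bottleneck.expansion == 4` and `self.planes == 64` as in tresnet_m:

planes, expansion = 64, 4
print(planes * 4 * expansion)  # body.layer3: 1024 channels (the old entry reported 256)
print(planes * 8 * expansion)  # body.layer4: 2048 channels (the old entry reported 512)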
@@ -228,7 +228,7 @@ class Xception(nn.Module):
 def _xception(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
         Xception, variant, pretrained, default_cfg=default_cfgs[variant],
-        feature_cfg=dict(use_hooks=True), **kwargs)
+        feature_cfg=dict(feature_cls='hook'), **kwargs)
 
 
 @register_model

@@ -174,7 +174,7 @@ class XceptionAligned(nn.Module):
 def _xception(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
         XceptionAligned, variant, pretrained, default_cfg=default_cfgs[variant],
-        feature_cfg=dict(flatten_sequential=True, use_hooks=True), **kwargs)
+        feature_cfg=dict(flatten_sequential=True, feature_cls='hook'), **kwargs)
 
 
 @register_model

@@ -100,6 +100,8 @@ def validate(args):
     # might as well try to validate something
     args.pretrained = args.pretrained or not args.checkpoint
     args.prefetcher = not args.no_prefetcher
+    if args.legacy_jit:
+        set_jit_legacy()
 
     # create model
     model = create_model(
@@ -119,8 +121,6 @@ def validate(args):
     model, test_time_pool = apply_test_time_pool(model, data_config, args)
 
     if args.torchscript:
-        if args.legacy_jit:
-            set_jit_legacy()
         torch.jit.optimized_execution(True)
         model = torch.jit.script(model)
