From 4e61c6a12d6e1a3fa1be554b6bff0b07fc1025fb Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 20 Jul 2020 16:10:31 -0700
Subject: [PATCH] Cleanup, refactoring of Feature extraction code, add tests,
 fix tests, non hook feature extraction working with torchscript

---
 tests/test_models.py            |  23 +++
 timm/models/features.py         | 272 ++++++++++++--------------------
 timm/models/gluon_xception.py   |   2 +-
 timm/models/helpers.py          |  20 +--
 timm/models/mobilenetv3.py      |   4 +-
 timm/models/nasnet.py           |   2 +-
 timm/models/pnasnet.py          |   2 +-
 timm/models/selecsls.py         |  28 +++-
 timm/models/tresnet.py          |  10 +-
 timm/models/xception.py         |   2 +-
 timm/models/xception_aligned.py |   2 +-
 validate.py                     |   4 +-
 12 files changed, 178 insertions(+), 193 deletions(-)

diff --git a/tests/test_models.py b/tests/test_models.py
index 2babd74a..e68e6599 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -106,3 +106,26 @@ def test_model_forward_torchscript(model_name, batch_size):
 
     assert outputs.shape[0] == batch_size
     assert not torch.isnan(outputs).any(), 'Output included NaNs'
+
+
+EXCLUDE_FEAT_FILTERS = [
+    'hrnet*', '*pruned*',  # hopefully fix at some point
+    'legacy*',  # not going to bother
+]
+
+@pytest.mark.timeout(120)
+@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FEAT_FILTERS))
+@pytest.mark.parametrize('batch_size', [1])
+def test_model_forward_features(model_name, batch_size):
+    """Run a single forward pass with each model in feature extraction mode"""
+    model = create_model(model_name, pretrained=False, features_only=True)
+    model.eval()
+    expected_channels = model.feature_info.channels()
+    assert len(expected_channels) >= 4  # all models here should have at least 4 feature levels by default, some 5 or 6
+    input_size = (3, 128, 128)  # jit compile is already a bit slow and we've tested normal res already...
+    outputs = model(torch.randn((batch_size, *input_size)))
+    assert len(expected_channels) == len(outputs)
+    for e, o in zip(expected_channels, outputs):
+        assert e == o.shape[1]
+        assert o.shape[0] == batch_size
+        assert not torch.isnan(o).any()
diff --git a/timm/models/features.py b/timm/models/features.py
index 46842f5d..757811af 100644
--- a/timm/models/features.py
+++ b/timm/models/features.py
@@ -75,8 +75,14 @@ class FeatureInfo:
 
 
 class FeatureHooks:
+    """ Feature Hook Helper
 
-    def __init__(self, hooks, named_modules, out_as_dict=False, out_map=None, default_hook_type='forward'):
+    This module helps with the setup and extraction of hooks for extracting features from
+    internal nodes in a model by node name. This works quite well in eager Python but needs
+    redesign for torchscript.
+    """
+
+    def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'):
         # setup feature hooks
         modules = {k: v for k, v in named_modules}
         for i, h in enumerate(hooks):
@@ -92,7 +98,6 @@ class FeatureHooks:
             else:
                 assert False, "Unsupported hook type"
         self._feature_outputs = defaultdict(OrderedDict)
-        self.out_as_dict = out_as_dict
 
     def _collect_output_hook(self, hook_id, *args):
         x = args[-1]  # tensor we want is last argument, output for fwd, input for fwd_pre
@@ -100,11 +105,8 @@ class FeatureHooks:
             x = x[0]  # unwrap input tuple
         self._feature_outputs[x.device][hook_id] = x
 
-    def get_output(self, device) -> List[torch.tensor]: # FIXME deal with diff return types for torchscript?
-        if self.out_as_dict:
-            output = self._feature_outputs[device]
-        else:
-            output = list(self._feature_outputs[device].values())
+    def get_output(self, device) -> Dict[str, torch.tensor]:
+        output = self._feature_outputs[device]
         self._feature_outputs[device] = OrderedDict()  # clear after reading
         return output
 
@@ -123,83 +125,72 @@ def _module_list(module, flatten_sequential=False):
     return ml
 
 
-class LayerGetterHooks(nn.ModuleDict):
-    """ LayerGetterHooks
-    TODO
-    """
-
-    def __init__(self, model, feature_info, flatten_sequential=False, out_as_dict=False, out_map=None,
-                 default_hook_type='forward'):
-        modules = _module_list(model, flatten_sequential=flatten_sequential)
-        remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type for f in feature_info}
-        layers = OrderedDict()
-        hooks = []
-        for new_name, old_name, module in modules:
-            layers[new_name] = module
-            for fn, fm in module.named_modules(prefix=old_name):
-                if fn in remaining:
-                    hooks.append(dict(module=fn, hook_type=remaining[fn]))
-                    del remaining[fn]
-            if not remaining:
-                break
-        assert not remaining, f'Return layers ({remaining}) are not present in model'
-        super(LayerGetterHooks, self).__init__(layers)
-        self.hooks = FeatureHooks(hooks, model.named_modules(), out_as_dict=out_as_dict, out_map=out_map)
+def _get_feature_info(net, out_indices):
+    feature_info = getattr(net, 'feature_info')
+    if isinstance(feature_info, FeatureInfo):
+        return feature_info.from_other(out_indices)
+    elif isinstance(feature_info, (list, tuple)):
+        return FeatureInfo(net.feature_info, out_indices)
+    else:
+        assert False, "Provided feature_info is not valid"
 
-    def forward(self, x) -> Dict[Any, torch.Tensor]:
-        for name, module in self.items():
-            x = module(x)
-        return self.hooks.get_output(x.device)
 
+def _get_return_layers(feature_info, out_map):
+    module_names = feature_info.module_name()
+    return_layers = {}
+    for i, name in enumerate(module_names):
+        return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]
+    return return_layers
 
-class LayerGetterDict(nn.ModuleDict):
-    """
-    Module wrapper that returns intermediate layers from a model as a dictionary
-    Originally based on concepts from IntermediateLayerGetter at
-    https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py
+class FeatureDictNet(nn.ModuleDict):
+    """ Feature extractor with OrderedDict return
 
-    It has a strong assumption that the modules have been registered into the model in the same
-    order as they are used. This means that one should **not** reuse the same nn.Module twice
-    in the forward if you want this to work.
+    Wrap a model and extract features as specified by the out indices, the network is
+    partially re-built from contained modules.
 
-    Additionally, it is only able to query submodules that are directly assigned to the model
-    class (`model.feature1`) or at most one Sequential container deep (`model.features.1`, so
-    long as `features` is a sequential container assigned to the model).
+    There is a strong assumption that the modules have been registered into the model in the same
+    order as they are used. There should be no reuse of the same nn.Module more than once, including
+    trivial modules like `self.relu = nn.ReLU`.
+    Only submodules that are directly assigned to the model class (`model.feature1`) or at most
+    one Sequential container deep (`model.features.1`, with flatten_sequential=True) can be captured.
     All Sequential containers that are directly assigned to the original model will have their
     modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
 
     Arguments:
-        model (nn.Module): model on which we will extract the features
-        return_layers (Dict[name, new_name]): a dict containing the names
-            of the modules for which the activations will be returned as
-            the key of the dict, and the value of the dict is the name
-            of the returned activation (which the user can specify).
-        concat (bool): whether to concatenate intermediate features that are lists or tuples
+        model (nn.Module): model from which we will extract the features
+        out_indices (tuple[int]): model output indices to extract features for
+        out_map (sequence): list or tuple specifying desired return id for each out index,
+            otherwise str(index) is used
+        feature_concat (bool): whether to concatenate intermediate features that are lists or tuples
             vs select element [0]
         flatten_sequential (bool): whether to flatten sequential modules assigned to model
-    """
-
-    def __init__(self, model, return_layers, concat=False, flatten_sequential=False):
+    """
+    def __init__(
+        self, model,
+        out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
+        super(FeatureDictNet, self).__init__()
+        self.feature_info = _get_feature_info(model, out_indices)
+        self.concat = feature_concat
         self.return_layers = {}
-        self.concat = concat
+        return_layers = _get_return_layers(self.feature_info, out_map)
         modules = _module_list(model, flatten_sequential=flatten_sequential)
         remaining = set(return_layers.keys())
         layers = OrderedDict()
         for new_name, old_name, module in modules:
             layers[new_name] = module
             if old_name in remaining:
-                self.return_layers[new_name] = return_layers[old_name]
+                # return id has to be consistently str type for torchscript
+                self.return_layers[new_name] = str(return_layers[old_name])
                 remaining.remove(old_name)
             if not remaining:
                 break
         assert not remaining and len(self.return_layers) == len(return_layers), \
             f'Return layers ({remaining}) are not present in model'
-        super(LayerGetterDict, self).__init__(layers)
+        self.update(layers)
 
-    def forward(self, x) -> Dict[Any, torch.Tensor]:
+    def _collect(self, x) -> (Dict[str, torch.Tensor]):
         out = OrderedDict()
         for name, module in self.items():
             x = module(x)
@@ -213,131 +204,74 @@ class LayerGetterDict(nn.ModuleDict):
                 out[out_id] = x
         return out
 
+    def forward(self, x) -> Dict[str, torch.Tensor]:
+        return self._collect(x)
 
-class LayerGetterList(nn.Sequential):
-    """
-    Module wrapper that returns intermediate layers from a model as a list
-
-    Originally based on concepts from IntermediateLayerGetter at
-    https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py
-
-    It has a strong assumption that the modules have been registered into the model in the same
-    order as they are used. This means that one should **not** reuse the same nn.Module twice
-    in the forward if you want this to work.
-    Additionally, it is only able to query submodules that are directly assigned to the model
-    class (`model.feature1`) or at most one Sequential container deep (`model.features.1`) so
-    long as `features` is a sequential container assigned to the model and flatten_sequent=True.
- - All Sequential containers that are directly assigned to the original model will have their - modules assigned to this module with the name `model.features.1` being changed to `model.features_1` - - Arguments: - model (nn.Module): model on which we will extract the features - return_layers (Dict[name, new_name]): a dict containing the names - of the modules for which the activations will be returned as - the key of the dict, and the value of the dict is the name - of the returned activation (which the user can specify). - concat (bool): whether to concatenate intermediate features that are lists or tuples - vs select element [0] - flatten_sequential (bool): whether to flatten sequential modules assigned to model +class FeatureListNet(FeatureDictNet): + """ Feature extractor with list return - """ - - def __init__(self, model, return_layers, concat=False, flatten_sequential=False): - super(LayerGetterList, self).__init__() - self.return_layers = {} - self.concat = concat - modules = _module_list(model, flatten_sequential=flatten_sequential) - remaining = set(return_layers.keys()) - for new_name, orig_name, module in modules: - self.add_module(new_name, module) - if orig_name in remaining: - self.return_layers[new_name] = return_layers[orig_name] - remaining.remove(orig_name) - if not remaining: - break - assert not remaining and len(self.return_layers) == len(return_layers), \ - f'Return layers ({remaining}) are not present in model' - - def forward(self, x) -> List[torch.Tensor]: - out = [] - for name, module in self.named_children(): - x = module(x) - if name in self.return_layers: - if isinstance(x, (tuple, list)): - # If model tap is a tuple or list, concat or select first element - # FIXME this may need to be more generic / flexible for some nets - out.append(torch.cat(x, 1) if self.concat else x[0]) - else: - out.append(x) - return out - - -def _resolve_feature_info(net, out_indices, feature_info=None): - if feature_info is None: - feature_info = getattr(net, 'feature_info') - if isinstance(feature_info, FeatureInfo): - return feature_info.from_other(out_indices) - elif isinstance(feature_info, (list, tuple)): - return FeatureInfo(net.feature_info, out_indices) - else: - assert False, "Provided feature_info is not valid" - - -def _get_return_layers(feature_info, out_map): - module_names = feature_info.module_name() - return_layers = {} - for i, name in enumerate(module_names): - return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i] - return return_layers - - -class FeatureNet(nn.Module): - """ FeatureNet - - Wrap a model and extract features as specified by the out indices, the network - is partially re-built from contained modules using the LayerGetters. - - Please read the docstrings of the LayerGetter classes, they will not work on all models. + See docstring for FeatureDictNet above, this class exists only to appease Torchscript typing constraints. + In eager Python we could have returned List[Tensor] vs Dict[id, Tensor] based on a member bool. 
""" def __init__( - self, net, - out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, use_hooks=False, - feature_info=None, feature_concat=False, flatten_sequential=False): - super(FeatureNet, self).__init__() - self.feature_info = _resolve_feature_info(net, out_indices, feature_info) - if use_hooks: - self.body = LayerGetterHooks(net, self.feature_info, out_as_dict=out_as_dict, out_map=out_map) - else: - return_layers = _get_return_layers(self.feature_info, out_map) - lg_args = dict(return_layers=return_layers, concat=feature_concat, flatten_sequential=flatten_sequential) - self.body = LayerGetterDict(net, **lg_args) if out_as_dict else LayerGetterList(net, **lg_args) + self, model, + out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False): + super(FeatureListNet, self).__init__( + model, out_indices=out_indices, out_map=out_map, feature_concat=feature_concat, + flatten_sequential=flatten_sequential) - def forward(self, x): - output = self.body(x) - return output + def forward(self, x) -> (List[torch.Tensor]): + return list(self._collect(x).values()) -class FeatureHookNet(nn.Module): +class FeatureHookNet(nn.ModuleDict): """ FeatureHookNet - Wrap a model and extract features specified by the out indices. + Wrap a model and extract features specified by the out indices using forward/forward-pre hooks. - Features are extracted via hooks without modifying the underlying network in any way. If only - part of the model is used it is up to the caller to remove unneeded layers as this wrapper - does not rewrite and remove unused top-level modules like FeatureNet with LayerGetter. + If `no_rewrite` is True, features are extracted via hooks without modifying the underlying + network in any way. + + If `no_rewrite` is False, the model will be re-written as in the + FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one. + + FIXME this does not currently work with Torchscript, see FeatureHooks class """ def __init__( - self, net, - out_indices=(0, 1, 2, 3, 4), out_as_dict=False, out_map=None, - feature_info=None, feature_concat=False): + self, model, + out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, no_rewrite=False, + feature_concat=False, flatten_sequential=False, default_hook_type='forward'): super(FeatureHookNet, self).__init__() - self.feature_info = _resolve_feature_info(net, out_indices, feature_info) - self.body = net - self.hooks = FeatureHooks( - self.feature_info, self.body.named_modules(), out_as_dict=out_as_dict, out_map=out_map) + assert not torch.jit.is_scripting() + self.feature_info = _get_feature_info(model, out_indices) + self.out_as_dict = out_as_dict + layers = OrderedDict() + hooks = [] + if no_rewrite: + assert not flatten_sequential + if hasattr(model, 'reset_classifier'): # make sure classifier is removed? 
+ model.reset_classifier(0) + layers['body'] = model + hooks.extend(self.feature_info) + else: + modules = _module_list(model, flatten_sequential=flatten_sequential) + remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type + for f in self.feature_info} + for new_name, old_name, module in modules: + layers[new_name] = module + for fn, fm in module.named_modules(prefix=old_name): + if fn in remaining: + hooks.append(dict(module=fn, hook_type=remaining[fn])) + del remaining[fn] + if not remaining: + break + assert not remaining, f'Return layers ({remaining}) are not present in model' + self.update(layers) + self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map) def forward(self, x): - self.body(x) - return self.hooks.get_output(x.device) + for name, module in self.items(): + x = module(x) + out = self.hooks.get_output(x.device) + return out if self.out_as_dict else list(out.values()) diff --git a/timm/models/gluon_xception.py b/timm/models/gluon_xception.py index da12bf64..aaf5fc1f 100644 --- a/timm/models/gluon_xception.py +++ b/timm/models/gluon_xception.py @@ -252,7 +252,7 @@ class Xception65(nn.Module): def _create_gluon_xception(variant, pretrained=False, **kwargs): return build_model_with_cfg( Xception65, variant, pretrained, default_cfg=default_cfgs[variant], - feature_cfg=dict(use_hooks=True), **kwargs) + feature_cfg=dict(feature_cls='hook'), **kwargs) @register_model diff --git a/timm/models/helpers.py b/timm/models/helpers.py index 593b7df5..a34593ce 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -8,7 +8,7 @@ import torch import torch.nn as nn import torch.utils.model_zoo as model_zoo -from .features import FeatureNet, FeatureHookNet +from .features import FeatureListNet, FeatureDictNet, FeatureHookNet from .layers import Conv2dSame @@ -234,15 +234,15 @@ def build_model_with_cfg( filter_fn=pretrained_filter_fn, strict=pretrained_strict) if features: - feature_cls = feature_cfg.pop('feature_cls', FeatureNet) - if isinstance(feature_cls, str): - feature_cls = feature_cls.lower() - if feature_cls == 'hook' or feature_cls == 'featurehooknet': - feature_cls = FeatureHookNet - else: - assert False, f'Unknown feature class {feature_cls}' - if feature_cls == FeatureHookNet and hasattr(model, 'reset_classifier'): - model.reset_classifier(0) + feature_cls = FeatureListNet + if 'feature_cls' in feature_cfg: + feature_cls = feature_cfg.pop('feature_cls') + if isinstance(feature_cls, str): + feature_cls = feature_cls.lower() + if 'hook' in feature_cls: + feature_cls = FeatureHookNet + else: + assert False, f'Unknown feature class {feature_cls}' model = feature_cls(model, **feature_cfg) return model diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index b99f4f7a..f8e3d738 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -211,7 +211,8 @@ class MobileNetV3Features(nn.Module): return features else: self.blocks(x) - return self.feature_hooks.get_output(x.device) + out = self.feature_hooks.get_output(x.device) + return list(out.values()) def _create_mnv3(model_kwargs, variant, pretrained=False): @@ -220,6 +221,7 @@ def _create_mnv3(model_kwargs, variant, pretrained=False): model_kwargs.pop('num_classes', 0) model_kwargs.pop('num_features', 0) model_kwargs.pop('head_conv', None) + model_kwargs.pop('head_bias', None) model_cls = MobileNetV3Features else: load_strict = True diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py index 27c59ecd..d682b46b 100644 --- a/timm/models/nasnet.py 
+++ b/timm/models/nasnet.py @@ -554,7 +554,7 @@ class NASNetALarge(nn.Module): def _create_nasnet(variant, pretrained=False, **kwargs): return build_model_with_cfg( NASNetALarge, variant, pretrained, default_cfg=default_cfgs[variant], - feature_cfg=dict(feature_cls='hook'), # not possible to re-write this model, must use FeatureHookNet + feature_cfg=dict(feature_cls='hook', no_rewrite=True), # not possible to re-write this model **kwargs) diff --git a/timm/models/pnasnet.py b/timm/models/pnasnet.py index e5f3b6d5..5a283ba9 100644 --- a/timm/models/pnasnet.py +++ b/timm/models/pnasnet.py @@ -337,7 +337,7 @@ class PNASNet5Large(nn.Module): def _create_pnasnet(variant, pretrained=False, **kwargs): return build_model_with_cfg( PNASNet5Large, variant, pretrained, default_cfg=default_cfgs[variant], - feature_cfg=dict(feature_cls='hook'), # not possible to re-write this model, must use FeatureHookNet + feature_cfg=dict(feature_cls='hook', no_rewrite=True), # not possible to re-write this model **kwargs) diff --git a/timm/models/selecsls.py b/timm/models/selecsls.py index 7161f723..6b541e95 100644 --- a/timm/models/selecsls.py +++ b/timm/models/selecsls.py @@ -74,6 +74,29 @@ class SequentialList(nn.Sequential): return x +class SelectSeq(nn.Module): + def __init__(self, mode='index', index=0): + super(SelectSeq, self).__init__() + self.mode = mode + self.index = index + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (List[torch.Tensor]) -> (torch.Tensor) + pass + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (Tuple[torch.Tensor]) -> (torch.Tensor) + pass + + def forward(self, x) -> torch.Tensor: + if self.mode == 'index': + return x[self.index] + else: + return torch.cat(x, dim=1) + + def conv_bn(in_chs, out_chs, k=3, stride=1, padding=None, dilation=1): if padding is None: padding = ((stride - 1) + dilation * (k - 1)) // 2 @@ -137,8 +160,10 @@ class SelecSLS(nn.Module): self.stem = conv_bn(in_chans, 32, stride=2) self.features = SequentialList(*[cfg['block'](*block_args) for block_args in cfg['features']]) + self.from_seq = SelectSeq() # from List[tensor] -> Tensor in module compatible way self.head = nn.Sequential(*[conv_bn(*conv_args) for conv_args in cfg['head']]) self.num_features = cfg['num_features'] + self.feature_info = cfg['feature_info'] self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) @@ -165,7 +190,7 @@ class SelecSLS(nn.Module): def forward_features(self, x): x = self.stem(x) x = self.features(x) - x = self.head(x[0]) + x = self.head(self.from_seq(x)) return x def forward(self, x): @@ -297,6 +322,7 @@ def _create_selecsls(variant, pretrained, model_kwargs): ]) else: raise ValueError('Invalid net configuration ' + variant + ' !!!') + cfg['feature_info'] = feature_info # this model can do 6 feature levels by default, unlike most others, leave as 0-4 to avoid surprises? 
return build_model_with_cfg( diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py index 27e604b8..50fc6f48 100644 --- a/timm/models/tresnet.py +++ b/timm/models/tresnet.py @@ -160,6 +160,9 @@ class Bottleneck(nn.Module): conv2d_iabn(planes, planes, kernel_size=3, stride=1, act_layer=act_layer, act_param=1e-3), aa_layer(channels=planes, filt_size=3, stride=2)) + reduce_layer_planes = max(planes * self.expansion // 8, 64) + self.se = FastSEModule(planes, reduce_layer_planes) if use_se else None + self.conv3 = conv2d_iabn( planes, planes * self.expansion, kernel_size=1, stride=1, act_layer="identity") @@ -167,9 +170,6 @@ class Bottleneck(nn.Module): self.downsample = downsample self.stride = stride - reduce_layer_planes = max(planes * self.expansion // 8, 64) - self.se = FastSEModule(planes, reduce_layer_planes) if use_se else None - def forward(self, x): if self.downsample is not None: residual = self.downsample(x) @@ -225,8 +225,8 @@ class TResNet(nn.Module): dict(num_chs=self.planes, reduction=2, module=''), # Not with S2D? dict(num_chs=self.planes, reduction=4, module='body.layer1'), dict(num_chs=self.planes * 2, reduction=8, module='body.layer2'), - dict(num_chs=self.planes * 4, reduction=16, module='body.layer3'), - dict(num_chs=self.planes * 8, reduction=32, module='body.layer4'), + dict(num_chs=self.planes * 4 * Bottleneck.expansion, reduction=16, module='body.layer3'), + dict(num_chs=self.planes * 8 * Bottleneck.expansion, reduction=32, module='body.layer4'), ] # head diff --git a/timm/models/xception.py b/timm/models/xception.py index 28a78344..db506828 100644 --- a/timm/models/xception.py +++ b/timm/models/xception.py @@ -228,7 +228,7 @@ class Xception(nn.Module): def _xception(variant, pretrained=False, **kwargs): return build_model_with_cfg( Xception, variant, pretrained, default_cfg=default_cfgs[variant], - feature_cfg=dict(use_hooks=True), **kwargs) + feature_cfg=dict(feature_cls='hook'), **kwargs) @register_model diff --git a/timm/models/xception_aligned.py b/timm/models/xception_aligned.py index 81334027..75ba7a27 100644 --- a/timm/models/xception_aligned.py +++ b/timm/models/xception_aligned.py @@ -174,7 +174,7 @@ class XceptionAligned(nn.Module): def _xception(variant, pretrained=False, **kwargs): return build_model_with_cfg( XceptionAligned, variant, pretrained, default_cfg=default_cfgs[variant], - feature_cfg=dict(flatten_sequential=True, use_hooks=True), **kwargs) + feature_cfg=dict(flatten_sequential=True, feature_cls='hook'), **kwargs) @register_model diff --git a/validate.py b/validate.py index 576567bd..6f5f76d1 100755 --- a/validate.py +++ b/validate.py @@ -100,6 +100,8 @@ def validate(args): # might as well try to validate something args.pretrained = args.pretrained or not args.checkpoint args.prefetcher = not args.no_prefetcher + if args.legacy_jit: + set_jit_legacy() # create model model = create_model( @@ -119,8 +121,6 @@ def validate(args): model, test_time_pool = apply_test_time_pool(model, data_config, args) if args.torchscript: - if args.legacy_jit: - set_jit_legacy() torch.jit.optimized_execution(True) model = torch.jit.script(model)
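
Note: a minimal usage sketch of the list-style feature extraction exercised by the new
test_model_forward_features test above. It assumes create_model is importable from the
top-level timm package (as in tests/test_models.py) and uses 'resnet50' purely as an
illustrative model name; any model covered by the new test should behave the same way.

    import torch
    from timm import create_model

    # features_only=True goes through build_model_with_cfg and wraps the backbone in FeatureListNet
    model = create_model('resnet50', pretrained=False, features_only=True)
    model.eval()

    # feature_info describes the selected feature levels; channels() gives one entry per level
    expected_channels = model.feature_info.channels()

    with torch.no_grad():
        outputs = model(torch.randn(1, 3, 224, 224))  # List[Tensor], one tensor per feature level

    for c, o in zip(expected_channels, outputs):
        assert o.shape[1] == c  # channel counts line up with feature_info, as the test asserts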
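A second sketch contrasts the dict and hook based extractors introduced in features.py.
Direct FeatureDictNet construction assumes the wrapped model exposes a feature_info
list or FeatureInfo, as the timm backbones touched by this patch do; the model name and
the 'c2'..'c5' return ids are illustrative only, and the feature_cfg comments simply
restate the xception/nasnet changes above rather than adding new API.

    import torch
    from timm import create_model
    from timm.models.features import FeatureDictNet

    # Dict return, keyed by out_map entries (or str(index) when out_map is None)
    backbone = create_model('resnet50', pretrained=False)
    features = FeatureDictNet(backbone, out_indices=(1, 2, 3, 4), out_map=('c2', 'c3', 'c4', 'c5'))
    with torch.no_grad():
        out = features(torch.randn(1, 3, 224, 224))
    print({k: tuple(v.shape) for k, v in out.items()})

    # Models that cannot be cleanly flattened into a ModuleDict opt into the hook-based
    # extractor instead, as their model files now do via build_model_with_cfg:
    #   feature_cfg=dict(feature_cls='hook')                    # rewrite + forward hooks
    #   feature_cfg=dict(feature_cls='hook', no_rewrite=True)   # hooks on the intact model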