diff --git a/timm/models/_features.py b/timm/models/_features.py
index 59b080cd..8e0b8984 100644
--- a/timm/models/_features.py
+++ b/timm/models/_features.py
@@ -11,10 +11,11 @@ Hacked together by / Copyright 2020 Ross Wightman
 from collections import OrderedDict, defaultdict
 from copy import deepcopy
 from functools import partial
-from typing import Dict, List, Tuple
+from typing import Dict, List, Sequence, Tuple, Union

 import torch
 import torch.nn as nn
+from torch.utils.checkpoint import checkpoint

 __all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']

@@ -88,12 +89,20 @@ class FeatureHooks:
     """ Feature Hook Helper

     This module helps with the setup and extraction of hooks for extracting features from
-    internal nodes in a model by node name. This works quite well in eager Python but needs
-    redesign for torchscript.
+    internal nodes in a model by node name.
+
+    FIXME This works well in eager Python but needs redesign for torchscript.
     """

-    def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'):
+    def __init__(
+            self,
+            hooks: Sequence[str],
+            named_modules: dict,
+            out_map: Sequence[Union[int, str]] = None,
+            default_hook_type: str = 'forward',
+    ):
         # setup feature hooks
+        self._feature_outputs = defaultdict(OrderedDict)
         modules = {k: v for k, v in named_modules}
         for i, h in enumerate(hooks):
             hook_name = h['module']
@@ -107,7 +116,6 @@ class FeatureHooks:
                 m.register_forward_hook(hook_fn)
             else:
                 assert False, "Unsupported hook type"
-        self._feature_outputs = defaultdict(OrderedDict)

     def _collect_output_hook(self, hook_id, *args):
         x = args[-1]  # tensor we want is last argument, output for fwd, input for fwd_pre
@@ -167,23 +175,30 @@ class FeatureDictNet(nn.ModuleDict):
     one Sequential container deep (`model.features.1`, with flatten_sequential=True) can be captured.
     All Sequential containers that are directly assigned to the original model will have their
     modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
-
-    Arguments:
-        model (nn.Module): model from which we will extract the features
-        out_indices (tuple[int]): model output indices to extract features for
-        out_map (sequence): list or tuple specifying desired return id for each out index,
-            otherwise str(index) is used
-        feature_concat (bool): whether to concatenate intermediate features that are lists or tuples
-            vs select element [0]
-        flatten_sequential (bool): whether to flatten sequential modules assigned to model
     """
     def __init__(
-            self, model,
-            out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
+            self,
+            model: nn.Module,
+            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
+            out_map: Sequence[Union[int, str]] = None,
+            feature_concat: bool = False,
+            flatten_sequential: bool = False,
+    ):
+        """
+        Args:
+            model: Model from which to extract features.
+            out_indices: Output indices of the model features to extract.
+            out_map: Return id mapping for each output index, otherwise str(index) is used.
+            feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
+                first element e.g. `x[0]`
+            flatten_sequential: Flatten first two levels of sequential modules in model (re-writes model modules)
+        """
         super(FeatureDictNet, self).__init__()
         self.feature_info = _get_feature_info(model, out_indices)
         self.concat = feature_concat
+        self.grad_checkpointing = False
         self.return_layers = {}
+
         return_layers = _get_return_layers(self.feature_info, out_map)
         modules = _module_list(model, flatten_sequential=flatten_sequential)
         remaining = set(return_layers.keys())
@@ -200,10 +215,21 @@ class FeatureDictNet(nn.ModuleDict):
             f'Return layers ({remaining}) are not present in model'
         self.update(layers)

+    def set_grad_checkpointing(self, enable: bool = True):
+        self.grad_checkpointing = enable
+
     def _collect(self, x) -> (Dict[str, torch.Tensor]):
         out = OrderedDict()
-        for name, module in self.items():
-            x = module(x)
+        for i, (name, module) in enumerate(self.items()):
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                # Skipping checkpoint of first module because need a gradient at input
+                # Skipping last because networks with in-place ops might fail w/ checkpointing enabled
+                # NOTE: first_or_last module could be static, but recalc in is_scripting guard to avoid jit issues
+                first_or_last_module = i == 0 or i == max(len(self) - 1, 0)
+                x = module(x) if first_or_last_module else checkpoint(module, x)
+            else:
+                x = module(x)
+
             if name in self.return_layers:
                 out_id = self.return_layers[name]
                 if isinstance(x, (tuple, list)):
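A note on the `_collect` hunk above: checkpointed segments recompute their forward pass during backward, so the wrapper deliberately leaves the first module un-checkpointed (its input is usually a raw image batch with `requires_grad=False`, giving checkpoint nothing to backprop into) and also skips the last module (in-place ops at the output boundary can fail under checkpointing). A minimal standalone sketch of the same pattern, using toy modules that are not part of timm:

```python
# Minimal sketch (not timm code) of the checkpoint-all-but-first-and-last pattern.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyFeatureNet(nn.ModuleDict):
    def __init__(self):
        super().__init__()
        self.update({f'stage{i}': nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.ReLU()) for i in range(4)})
        self.grad_checkpointing = False

    def set_grad_checkpointing(self, enable: bool = True):
        self.grad_checkpointing = enable

    def forward(self, x):
        for i, (name, module) in enumerate(self.items()):
            # Middle modules are checkpointed: their activations are freed after
            # forward and recomputed during backward, trading compute for memory.
            first_or_last = i == 0 or i == len(self) - 1
            if self.grad_checkpointing and not first_or_last:
                x = checkpoint(module, x)
            else:
                x = module(x)
        return x


net = TinyFeatureNet()
net.set_grad_checkpointing(True)
out = net(torch.randn(2, 8, 16, 16))
out.sum().backward()  # gradients flow through the recomputed segments
```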
@@ -221,15 +247,29 @@ class FeatureListNet(FeatureDictNet):
     """ Feature extractor with list return

-    See docstring for FeatureDictNet above, this class exists only to appease Torchscript typing constraints.
-    In eager Python we could have returned List[Tensor] vs Dict[id, Tensor] based on a member bool.
+    A specialization of FeatureDictNet that always returns features as a list (values() of dict).
     """
     def __init__(
-            self, model,
-            out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
+            self,
+            model: nn.Module,
+            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
+            feature_concat: bool = False,
+            flatten_sequential: bool = False,
+    ):
+        """
+        Args:
+            model: Model from which to extract features.
+            out_indices: Output indices of the model features to extract.
+            feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
+                first element e.g. `x[0]`
+            flatten_sequential: Flatten first two levels of sequential modules in model (re-writes model modules)
+        """
         super(FeatureListNet, self).__init__(
-            model, out_indices=out_indices, out_map=out_map, feature_concat=feature_concat,
-            flatten_sequential=flatten_sequential)
+            model,
+            out_indices=out_indices,
+            feature_concat=feature_concat,
+            flatten_sequential=flatten_sequential,
+        )

     def forward(self, x) -> (List[torch.Tensor]):
         return list(self._collect(x).values())
@@ -249,13 +289,33 @@ class FeatureHookNet(nn.ModuleDict):
     FIXME this does not currently work with Torchscript, see FeatureHooks class
     """
     def __init__(
-            self, model,
-            out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, no_rewrite=False,
-            feature_concat=False, flatten_sequential=False, default_hook_type='forward'):
+            self,
+            model: nn.Module,
+            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
+            out_map: Sequence[Union[int, str]] = None,
+            out_as_dict: bool = False,
+            no_rewrite: bool = False,
+            flatten_sequential: bool = False,
+            default_hook_type: str = 'forward',
+    ):
+        """
+
+        Args:
+            model: Model from which to extract features.
+            out_indices: Output indices of the model features to extract.
+            out_map: Return id mapping for each output index, otherwise str(index) is used.
+            out_as_dict: Output features as a dict.
+            no_rewrite: Enforce that model is not re-written if True, i.e. no modules are removed / changed.
+                flatten_sequential arg must also be False if this is set True.
+            flatten_sequential: Re-write modules by flattening first two levels of nn.Sequential containers.
+            default_hook_type: The default hook type to use if not specified in model.feature_info.
+        """
         super(FeatureHookNet, self).__init__()
         assert not torch.jit.is_scripting()
         self.feature_info = _get_feature_info(model, out_indices)
         self.out_as_dict = out_as_dict
+        self.grad_checkpointing = False
+
         layers = OrderedDict()
         hooks = []
         if no_rewrite:
@@ -266,8 +326,10 @@ class FeatureHookNet(nn.ModuleDict):
             hooks.extend(self.feature_info.get_dicts())
         else:
             modules = _module_list(model, flatten_sequential=flatten_sequential)
-            remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
-                         for f in self.feature_info.get_dicts()}
+            remaining = {
+                f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
+                for f in self.feature_info.get_dicts()
+            }
             for new_name, old_name, module in modules:
                 layers[new_name] = module
                 for fn, fm in module.named_modules(prefix=old_name):
@@ -280,8 +342,18 @@ class FeatureHookNet(nn.ModuleDict):
         self.update(layers)
         self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)

+    def set_grad_checkpointing(self, enable: bool = True):
+        self.grad_checkpointing = enable
+
     def forward(self, x):
-        for name, module in self.items():
-            x = module(x)
+        for i, (name, module) in enumerate(self.items()):
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                # Skipping checkpoint of first module because need a gradient at input
+                # Skipping last because networks with in-place ops might fail w/ checkpointing enabled
+                # NOTE: first_or_last module could be static, but recalc in is_scripting guard to avoid jit issues
+                first_or_last_module = i == 0 or i == max(len(self) - 1, 0)
+                x = module(x) if first_or_last_module else checkpoint(module, x)
+            else:
+                x = module(x)
         out = self.hooks.get_output(x.device)
         return out if self.out_as_dict else list(out.values())
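For readers unfamiliar with the hook path used by `FeatureHookNet` and `FeatureHooks` above: features are captured by registering forward hooks on named submodules rather than by re-running a rewritten module list. A rough, self-contained sketch of that mechanism (the `HookCollector` class and its names are illustrative, not the timm implementation):

```python
# Sketch of hook-based feature collection: a forward hook stores each tagged
# module's output in an ordered dict keyed by module name.
from collections import OrderedDict

import torch
import torch.nn as nn


class HookCollector:
    def __init__(self, model: nn.Module, module_names):
        self.outputs = OrderedDict()
        named = dict(model.named_modules())
        for name in module_names:
            named[name].register_forward_hook(self._make_hook(name))

    def _make_hook(self, name):
        def hook(module, inputs, output):
            self.outputs[name] = output  # output tensor for a 'forward' hook
        return hook


model = nn.Sequential(OrderedDict([
    ('stem', nn.Conv2d(3, 8, 3, padding=1)),
    ('stage1', nn.Conv2d(8, 16, 3, stride=2, padding=1)),
    ('stage2', nn.Conv2d(16, 32, 3, stride=2, padding=1)),
]))
collector = HookCollector(model, ['stage1', 'stage2'])
model(torch.randn(1, 3, 32, 32))
print({k: tuple(v.shape) for k, v in collector.outputs.items()})
# {'stage1': (1, 16, 16, 16), 'stage2': (1, 32, 8, 8)}
```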
diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py
index a3866fec..83ecbb1c 100644
--- a/timm/models/efficientnet.py
+++ b/timm/models/efficientnet.py
@@ -41,6 +41,7 @@ from typing import List
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, GroupNormAct
@@ -211,6 +212,7 @@ class EfficientNetFeatures(nn.Module):
         norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
         se_layer = se_layer or SqueezeExcite
         self.drop_rate = drop_rate
+        self.grad_checkpointing = False

         # Stem
         if not fix_stem:
@@ -241,6 +243,10 @@ class EfficientNetFeatures(nn.Module):
             hooks = self.feature_info.get_dicts(keys=('module', 'hook_type'))
             self.feature_hooks = FeatureHooks(hooks, self.named_modules())

+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        self.grad_checkpointing = enable
+
     def forward(self, x) -> List[torch.Tensor]:
         x = self.conv_stem(x)
         x = self.bn1(x)
@@ -249,7 +255,10 @@ class EfficientNetFeatures(nn.Module):
             if 0 in self._stage_out_idx:
                 features.append(x)  # add stem out
             for i, b in enumerate(self.blocks):
-                x = b(x)
+                if self.grad_checkpointing and not torch.jit.is_scripting():
+                    x = checkpoint(b, x)
+                else:
+                    x = b(x)
                 if i + 1 in self._stage_out_idx:
                     features.append(x)
             return features
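With the change above, a features-only EfficientNet (which is an `EfficientNetFeatures` instance rather than a wrapped `EfficientNet`) exposes `set_grad_checkpointing` directly. A usage sketch, assuming a timm build that includes this diff:

```python
# Enable per-block gradient checkpointing on a features-only EfficientNet.
import timm
import torch

model = timm.create_model('efficientnet_b0', features_only=True)
model.set_grad_checkpointing(True)  # routes each block through checkpoint() in forward

x = torch.randn(2, 3, 224, 224)
features = model(x)  # List[torch.Tensor], one tensor per out_index
for f in features:
    print(tuple(f.shape))
```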
diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py
index e1da91a2..5943781f 100644
--- a/timm/models/mobilenetv3.py
+++ b/timm/models/mobilenetv3.py
@@ -12,6 +12,7 @@ from typing import List
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from timm.layers import SelectAdaptivePool2d, Linear, create_conv2d, get_norm_act_layer
@@ -188,6 +189,7 @@ class MobileNetV3Features(nn.Module):
         norm_layer = norm_layer or nn.BatchNorm2d
         se_layer = se_layer or SqueezeExcite
         self.drop_rate = drop_rate
+        self.grad_checkpointing = False

         # Stem
         if not fix_stem:
@@ -220,6 +222,10 @@ class MobileNetV3Features(nn.Module):
             hooks = self.feature_info.get_dicts(keys=('module', 'hook_type'))
             self.feature_hooks = FeatureHooks(hooks, self.named_modules())

+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        self.grad_checkpointing = enable
+
     def forward(self, x) -> List[torch.Tensor]:
         x = self.conv_stem(x)
         x = self.bn1(x)
@@ -229,7 +235,10 @@ class MobileNetV3Features(nn.Module):
             if 0 in self._stage_out_idx:
                 features.append(x)  # add stem out
             for i, b in enumerate(self.blocks):
-                x = b(x)
+                if self.grad_checkpointing and not torch.jit.is_scripting():
+                    x = checkpoint(b, x)
+                else:
+                    x = b(x)
                 if i + 1 in self._stage_out_idx:
                     features.append(x)
             return features
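The MobileNetV3 change mirrors the EfficientNet one. A rough way to sanity-check the intended memory/compute trade on a CUDA machine (illustrative only; exact numbers depend on model, batch size, and PyTorch version):

```python
# Compare peak CUDA memory for a backward pass with and without checkpointing.
import timm
import torch


def peak_mem_mb(model, enable):
    model.set_grad_checkpointing(enable)
    torch.cuda.reset_peak_memory_stats()
    x = torch.randn(8, 3, 224, 224, device='cuda')
    loss = sum(f.float().mean() for f in model(x))
    loss.backward()
    model.zero_grad(set_to_none=True)  # free grads before the next measurement
    return torch.cuda.max_memory_allocated() / (1024 ** 2)


model = timm.create_model('mobilenetv3_large_100', features_only=True).cuda().train()
print('no checkpointing:', peak_mem_mb(model, False))
print('checkpointing:   ', peak_mem_mb(model, True))
```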