Add grad_checkpointing support to features_only, test in EfficientDet.

Ross Wightman 1 year ago
parent 45af496197
commit 2cfff0581b

@ -11,10 +11,11 @@ Hacked together by / Copyright 2020 Ross Wightman
from collections import OrderedDict, defaultdict
from copy import deepcopy
from functools import partial
from typing import Dict, List, Tuple
from typing import Dict, List, Sequence, Tuple, Union
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']
@ -88,12 +89,20 @@ class FeatureHooks:
""" Feature Hook Helper
This module helps with the setup and extraction of hooks for extracting features from
internal nodes in a model by node name. This works quite well in eager Python but needs
redesign for torchscript.
internal nodes in a model by node name.
FIXME This works well in eager Python but needs redesign for torchscript.
def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'):
def __init__(
hooks: Sequence[str],
named_modules: dict,
out_map: Sequence[Union[int, str]] = None,
default_hook_type: str = 'forward',
# setup feature hooks
self._feature_outputs = defaultdict(OrderedDict)
modules = {k: v for k, v in named_modules}
for i, h in enumerate(hooks):
hook_name = h['module']
@ -107,7 +116,6 @@ class FeatureHooks:
assert False, "Unsupported hook type"
self._feature_outputs = defaultdict(OrderedDict)
def _collect_output_hook(self, hook_id, *args):
x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre
@ -167,23 +175,30 @@ class FeatureDictNet(nn.ModuleDict):
one Sequential container deep (`model.features.1`, with flatten_sequential=True) can be captured.
All Sequential containers that are directly assigned to the original model will have their
modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
model (nn.Module): model from which we will extract the features
out_indices (tuple[int]): model output indices to extract features for
out_map (sequence): list or tuple specifying desired return id for each out index,
otherwise str(index) is used
feature_concat (bool): whether to concatenate intermediate features that are lists or tuples
vs select element [0]
flatten_sequential (bool): whether to flatten sequential modules assigned to model
def __init__(
self, model,
out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
model: nn.Module,
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
out_map: Sequence[Union[int, str]] = None,
feature_concat: bool = False,
flatten_sequential: bool = False,
model: Model from which to extract features.
out_indices: Output indices of the model features to extract.
out_map: Return id mapping for each output index, otherwise str(index) is used.
feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
first element e.g. `x[0]`
flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
super(FeatureDictNet, self).__init__()
self.feature_info = _get_feature_info(model, out_indices)
self.concat = feature_concat
self.grad_checkpointing = False
self.return_layers = {}
return_layers = _get_return_layers(self.feature_info, out_map)
modules = _module_list(model, flatten_sequential=flatten_sequential)
remaining = set(return_layers.keys())
@ -200,10 +215,21 @@ class FeatureDictNet(nn.ModuleDict):
f'Return layers ({remaining}) are not present in model'
def set_grad_checkpointing(self, enable: bool = True):
self.grad_checkpointing = enable
def _collect(self, x) -> (Dict[str, torch.Tensor]):
out = OrderedDict()
for name, module in self.items():
x = module(x)
for i, (name, module) in enumerate(self.items()):
if self.grad_checkpointing and not torch.jit.is_scripting():
# Skip checkpointing the first module because a gradient is needed at the input.
# Skip the last module because networks with in-place ops might fail with checkpointing enabled.
# NOTE: first_or_last_module could be static, but it is recalculated inside the is_scripting guard to avoid jit issues
first_or_last_module = i == 0 or i == max(len(self) - 1, 0)
x = module(x) if first_or_last_module else checkpoint(module, x)
x = module(x)
if name in self.return_layers:
out_id = self.return_layers[name]
if isinstance(x, (tuple, list)):
@ -221,15 +247,29 @@ class FeatureDictNet(nn.ModuleDict):
class FeatureListNet(FeatureDictNet):
""" Feature extractor with list return
See docstring for FeatureDictNet above, this class exists only to appease Torchscript typing constraints.
In eager Python we could have returned List[Tensor] vs Dict[id, Tensor] based on a member bool.
A specialization of FeatureDictNet that always returns features as a list (values() of dict).
def __init__(
self, model,
out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
model: nn.Module,
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
feature_concat: bool = False,
flatten_sequential: bool = False,
model: Model from which to extract features.
out_indices: Output indices of the model features to extract.
feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
first element e.g. `x[0]`
flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
super(FeatureListNet, self).__init__(
model, out_indices=out_indices, out_map=out_map, feature_concat=feature_concat,
def forward(self, x) -> (List[torch.Tensor]):
return list(self._collect(x).values())
@ -249,13 +289,33 @@ class FeatureHookNet(nn.ModuleDict):
FIXME this does not currently work with Torchscript, see FeatureHooks class
def __init__(
self, model,
out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, no_rewrite=False,
feature_concat=False, flatten_sequential=False, default_hook_type='forward'):
model: nn.Module,
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
out_map: Sequence[Union[int, str]] = None,
out_as_dict: bool = False,
no_rewrite: bool = False,
flatten_sequential: bool = False,
default_hook_type: str = 'forward',
model: Model from which to extract features.
out_indices: Output indices of the model features to extract.
out_map: Return id mapping for each output index, otherwise str(index) is used.
out_as_dict: Output features as a dict.
no_rewrite: Enforce that model is not re-written if True, i.e. no modules are removed / changed.
flatten_sequential arg must also be False if this is set True.
flatten_sequential: Re-write modules by flattening first two levels of nn.Sequential containers.
default_hook_type: The default hook type to use if not specified in model.feature_info.
super(FeatureHookNet, self).__init__()
assert not torch.jit.is_scripting()
self.feature_info = _get_feature_info(model, out_indices)
self.out_as_dict = out_as_dict
self.grad_checkpointing = False
layers = OrderedDict()
hooks = []
if no_rewrite:
@ -266,8 +326,10 @@ class FeatureHookNet(nn.ModuleDict):
modules = _module_list(model, flatten_sequential=flatten_sequential)
remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
for f in self.feature_info.get_dicts()}
remaining = {
f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
for f in self.feature_info.get_dicts()
for new_name, old_name, module in modules:
layers[new_name] = module
for fn, fm in module.named_modules(prefix=old_name):
@ -280,8 +342,18 @@ class FeatureHookNet(nn.ModuleDict):
self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)
def set_grad_checkpointing(self, enable: bool = True):
self.grad_checkpointing = enable
def forward(self, x):
for name, module in self.items():
x = module(x)
for i, (name, module) in enumerate(self.items()):
if self.grad_checkpointing and not torch.jit.is_scripting():
# Skip checkpointing the first module because a gradient is needed at the input.
# Skip the last module because networks with in-place ops might fail with checkpointing enabled.
# NOTE: first_or_last_module could be static, but it is recalculated inside the is_scripting guard to avoid jit issues
first_or_last_module = i == 0 or i == max(len(self) - 1, 0)
x = module(x) if first_or_last_module else checkpoint(module, x)
x = module(x)
out = self.hooks.get_output(x.device)
return out if self.out_as_dict else list(out.values())

@ -41,6 +41,7 @@ from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, GroupNormAct
@ -211,6 +212,7 @@ class EfficientNetFeatures(nn.Module):
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
se_layer = se_layer or SqueezeExcite
self.drop_rate = drop_rate
self.grad_checkpointing = False
# Stem
if not fix_stem:
@ -241,6 +243,10 @@ class EfficientNetFeatures(nn.Module):
hooks = self.feature_info.get_dicts(keys=('module', 'hook_type'))
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
def forward(self, x) -> List[torch.Tensor]:
x = self.conv_stem(x)
x = self.bn1(x)
@ -249,7 +255,10 @@ class EfficientNetFeatures(nn.Module):
if 0 in self._stage_out_idx:
features.append(x) # add stem out
for i, b in enumerate(self.blocks):
x = b(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(b, x)
x = b(x)
if i + 1 in self._stage_out_idx:
return features

@ -12,6 +12,7 @@ from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from timm.layers import SelectAdaptivePool2d, Linear, create_conv2d, get_norm_act_layer
@ -188,6 +189,7 @@ class MobileNetV3Features(nn.Module):
norm_layer = norm_layer or nn.BatchNorm2d
se_layer = se_layer or SqueezeExcite
self.drop_rate = drop_rate
self.grad_checkpointing = False
# Stem
if not fix_stem:
@ -220,6 +222,10 @@ class MobileNetV3Features(nn.Module):
hooks = self.feature_info.get_dicts(keys=('module', 'hook_type'))
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
def forward(self, x) -> List[torch.Tensor]:
x = self.conv_stem(x)
x = self.bn1(x)
@ -229,7 +235,10 @@ class MobileNetV3Features(nn.Module):
if 0 in self._stage_out_idx:
features.append(x) # add stem out
for i, b in enumerate(self.blocks):
x = b(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(b, x)
x = b(x)
if i + 1 in self._stage_out_idx:
return features
