Post merge cleanup, restore previous unwrap fn

pull/1239/head
Ross Wightman 3 years ago
parent 3b6ba76126
commit 690f31d02d

@@ -7,33 +7,38 @@ import fnmatch
 import torch
 from torchvision.ops.misc import FrozenBatchNorm2d
 
-from .model_ema import ModelEma
+
+_SUB_MODULE_ATTR = ('module', 'model')
 
 
-def unwrap_model(model):
-    if isinstance(model, ModelEma):
-        return unwrap_model(model.ema)
-    else:
-        return model.module if hasattr(model, 'module') else model
+def unwrap_model(model, recursive=True):
+    for attr in _SUB_MODULE_ATTR:
+        sub_module = getattr(model, attr, None)
+        if sub_module is not None:
+            return unwrap_model(sub_module) if recursive else sub_module
+    return model
 
 
 def get_state_dict(model, unwrap_fn=unwrap_model):
     return unwrap_fn(model).state_dict()
 
 
 def avg_sq_ch_mean(model, input, output):
-    "calculate average channel square mean of output activations"
-    return torch.mean(output.mean(axis=[0,2,3])**2).item()
+    """ calculate average channel square mean of output activations
+    """
+    return torch.mean(output.mean(axis=[0, 2, 3]) ** 2).item()
 
 
 def avg_ch_var(model, input, output):
-    "calculate average channel variance of output activations"
-    return torch.mean(output.var(axis=[0,2,3])).item()\
+    """calculate average channel variance of output activations
+    """
+    return torch.mean(output.var(axis=[0, 2, 3])).item()
 
 
 def avg_ch_var_residual(model, input, output):
-    "calculate average channel variance of output activations"
-    return torch.mean(output.var(axis=[0,2,3])).item()
+    """calculate average channel variance of output activations
+    """
+    return torch.mean(output.var(axis=[0, 2, 3])).item()
 
 
 class ActivationStatsHook:
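For reference, the restored unwrap_model peels off common wrapper attributes ('module', 'model') instead of special-casing ModelEma. A minimal sketch of the behaviour (FakeDDP is a made-up stand-in for a DistributedDataParallel-style wrapper, and the import assumes this diff is against timm/utils/model.py):

import torch.nn as nn
from timm.utils.model import unwrap_model  # assumed module path for this file

# FakeDDP is a made-up stand-in for wrappers (e.g. DistributedDataParallel)
# that expose the wrapped network via a `.module` attribute.
class FakeDDP(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

net = nn.Linear(4, 4)
wrapped = FakeDDP(FakeDDP(net))

# recursive=True (the default) keeps peeling `.module` / `.model` wrappers
assert unwrap_model(wrapped) is net
# recursive=False strips only the outermost wrapper
assert unwrap_model(wrapped, recursive=False) is wrapped.module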
@@ -62,15 +67,16 @@ class ActivationStatsHook:
             raise ValueError("Please provide `hook_fns` for each `hook_fn_locs`, \
                 their lengths are different.")
         self.stats = dict((hook_fn.__name__, []) for hook_fn in hook_fns)
         for hook_fn_loc, hook_fn in zip(hook_fn_locs, hook_fns):
             self.register_hook(hook_fn_loc, hook_fn)
 
     def _create_hook(self, hook_fn):
         def append_activation_stats(module, input, output):
             out = hook_fn(module, input, output)
             self.stats[hook_fn.__name__].append(out)
+
         return append_activation_stats
 
     def register_hook(self, hook_fn_loc, hook_fn):
         for name, module in self.model.named_modules():
             if not fnmatch.fnmatch(name, hook_fn_loc):
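register_hook matches module names from named_modules() against each hook_fn_loc with fnmatch, so one wildcard pattern can attach a hook fn to several layers at once. A minimal sketch with a toy model (TinyNet and the 'stage*' pattern are illustrative, not from this diff; the import path is an assumption):

import torch
import torch.nn as nn
from timm.utils.model import ActivationStatsHook, avg_ch_var  # assumed module path

# Toy two-stage conv net; the module names 'stage1' / 'stage2' are illustrative.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.stage1 = nn.Conv2d(3, 8, 3, padding=1)
        self.stage2 = nn.Conv2d(8, 8, 3, padding=1)

    def forward(self, x):
        return self.stage2(self.stage1(x))

model = TinyNet()
# 'stage*' is an fnmatch pattern, so this single entry hooks both convs.
hook = ActivationStatsHook(model, hook_fn_locs=['stage*'], hook_fns=[avg_ch_var])
_ = model(torch.randn(2, 3, 16, 16))
print(hook.stats['avg_ch_var'])  # two values, one per matched module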
@@ -78,17 +84,18 @@ class ActivationStatsHook:
             module.register_forward_hook(self._create_hook(hook_fn))
 
 
-def extract_spp_stats(model,
-                      hook_fn_locs,
-                      hook_fns,
-                      input_shape=[8, 3, 224, 224]):
+def extract_spp_stats(
+        model,
+        hook_fn_locs,
+        hook_fns,
+        input_shape=[8, 3, 224, 224]):
     """Extract average square channel mean and variance of activations during
     forward pass to plot Signal Propogation Plots (SPP).
 
     Paper: https://arxiv.org/abs/2101.08692
     Example Usage: https://gist.github.com/amaarora/6e56942fcb46e67ba203f3009b30d950
     """
     x = torch.normal(0., 1., input_shape)
     hook = ActivationStatsHook(model, hook_fn_locs=hook_fn_locs, hook_fns=hook_fns)
     _ = model(x)
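extract_spp_stats builds an ActivationStatsHook, pushes a random normal batch through the model, and hands the hook back so its stats can be plotted. A hedged usage sketch (the toy model is mine, and the assumption that the function returns the hook instance comes from the unchanged line just past this hunk):

import torch.nn as nn
from timm.utils.model import extract_spp_stats, avg_sq_ch_mean, avg_ch_var_residual  # assumed module path

# Toy model; '0' and '2' below are the nn.Sequential child names matched by fnmatch.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 8, 3, padding=1),
)
hook = extract_spp_stats(
    model,
    hook_fn_locs=['0', '2'],
    hook_fns=[avg_sq_ch_mean, avg_ch_var_residual],
    input_shape=[2, 3, 32, 32])
# One value per matched module, collected during the single forward pass
print(hook.stats['avg_sq_ch_mean'], hook.stats['avg_ch_var_residual'])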
@@ -186,7 +193,7 @@ def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True,
     named_modules = submodules
     submodules = [root_module.get_submodule(m) for m in submodules]
 
-    if not(len(submodules)):
+    if not (len(submodules)):
         named_modules, submodules = list(zip(*root_module.named_children()))
 
     for n, m in zip(named_modules, submodules):
@@ -201,13 +208,14 @@ def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True,
                     module.get_submodule(split[0]).add_module(split[1], submodule)
                 else:
                     module.add_module(name, submodule)
+
             # Freeze batch norm
             if mode == 'freeze':
                 res = freeze_batch_norm_2d(m)
                 # It's possible that `m` is a type of BatchNorm in itself, in which case `unfreeze_batch_norm_2d` won't
                 # convert it in place, but will return the converted result. In this case `res` holds the converted
                 # result and we may try to re-assign the named module
                 if isinstance(m, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
                     _add_submodule(root_module, n, res)
             # Unfreeze batch norm
             else:
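This freeze/unfreeze path is normally reached through the module's freeze and unfreeze helpers, which are not shown in this diff. A rough sketch of the intended effect, assuming those helpers and the import path below exist as in timm master (the toy model and submodule names are made up):

import torch.nn as nn
from torchvision.ops.misc import FrozenBatchNorm2d
from timm.utils.model import freeze, unfreeze  # assumed public wrappers around _freeze_unfreeze

# Toy model; 'stem' and 'head' are illustrative submodule names.
model = nn.Sequential()
model.add_module('stem', nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)))
model.add_module('head', nn.Linear(8, 10))

freeze(model, ['stem'])
# Frozen parameters stop receiving gradients; with include_bn_running_stats=True
# (the default), BatchNorm2d inside the frozen submodule is swapped for FrozenBatchNorm2d.
assert all(not p.requires_grad for p in model.stem.parameters())
assert isinstance(model.stem[1], FrozenBatchNorm2d)

unfreeze(model, ['stem'])
assert all(p.requires_grad for p in model.stem.parameters())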
