Post merge cleanup, restore previous unwrap fn

pull/1239/head
Ross Wightman 3 years ago
parent 3b6ba76126
commit 690f31d02d

@@ -7,33 +7,38 @@ import fnmatch
import torch
from torchvision.ops.misc import FrozenBatchNorm2d
from .model_ema import ModelEma

_SUB_MODULE_ATTR = ('module', 'model')

def unwrap_model(model):
    if isinstance(model, ModelEma):
        return unwrap_model(model.ema)
    else:
        return model.module if hasattr(model, 'module') else model

def unwrap_model(model, recursive=True):
    for attr in _SUB_MODULE_ATTR:
        sub_module = getattr(model, attr, None)
        if sub_module is not None:
            return unwrap_model(sub_module) if recursive else sub_module
    return model

def get_state_dict(model, unwrap_fn=unwrap_model):
    return unwrap_fn(model).state_dict()

def avg_sq_ch_mean(model, input, output):
    "calculate average channel square mean of output activations"
    return torch.mean(output.mean(axis=[0,2,3])**2).item()

def avg_sq_ch_mean(model, input, output):
    """ calculate average channel square mean of output activations
    """
    return torch.mean(output.mean(axis=[0, 2, 3]) ** 2).item()

def avg_ch_var(model, input, output):
    "calculate average channel variance of output activations"
    return torch.mean(output.var(axis=[0,2,3])).item()\

def avg_ch_var(model, input, output):
    """calculate average channel variance of output activations
    """
    return torch.mean(output.var(axis=[0, 2, 3])).item()

def avg_ch_var_residual(model, input, output):
    "calculate average channel variance of output activations"
    return torch.mean(output.var(axis=[0,2,3])).item()

def avg_ch_var_residual(model, input, output):
    """calculate average channel variance of output activations
    """
    return torch.mean(output.var(axis=[0, 2, 3])).item()

class ActivationStatsHook:
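
For context, a minimal usage sketch of the restored unwrap_model / get_state_dict pair. The import path (timm.utils) and the DataParallel wrapper are assumptions for illustration, not part of this commit:

import torch.nn as nn
from timm.utils import get_state_dict, unwrap_model  # assumed import path

base = nn.Linear(8, 4)
wrapped = nn.DataParallel(base)  # adds a `.module` indirection, like DDP does

# unwrap_model strips the wrapper (and ModelEma) before state_dict access,
# so checkpoint keys come out without the 'module.' prefix.
assert unwrap_model(wrapped) is base
assert set(get_state_dict(wrapped)) == set(base.state_dict())
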
@@ -62,15 +67,16 @@ class ActivationStatsHook:
            raise ValueError("Please provide `hook_fns` for each `hook_fn_locs`, \
                their lengths are different.")
        self.stats = dict((hook_fn.__name__, []) for hook_fn in hook_fns)
        for hook_fn_loc, hook_fn in zip(hook_fn_locs, hook_fns):
        for hook_fn_loc, hook_fn in zip(hook_fn_locs, hook_fns):
            self.register_hook(hook_fn_loc, hook_fn)

    def _create_hook(self, hook_fn):
        def append_activation_stats(module, input, output):
            out = hook_fn(module, input, output)
            self.stats[hook_fn.__name__].append(out)
        return append_activation_stats

    def register_hook(self, hook_fn_loc, hook_fn):
        for name, module in self.model.named_modules():
            if not fnmatch.fnmatch(name, hook_fn_loc):
@@ -78,17 +84,18 @@ class ActivationStatsHook:
            module.register_forward_hook(self._create_hook(hook_fn))

def extract_spp_stats(model,
                      hook_fn_locs,
                      hook_fns,
                      input_shape=[8, 3, 224, 224]):

def extract_spp_stats(
        model,
        hook_fn_locs,
        hook_fns,
        input_shape=[8, 3, 224, 224]):
"""Extract average square channel mean and variance of activations during
forward pass to plot Signal Propogation Plots (SPP).
Paper: https://arxiv.org/abs/2101.08692
Example Usage: https://gist.github.com/amaarora/6e56942fcb46e67ba203f3009b30d950
"""
"""
x = torch.normal(0., 1., input_shape)
hook = ActivationStatsHook(model, hook_fn_locs=hook_fn_locs, hook_fns=hook_fns)
_ = model(x)
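
A rough usage sketch for the SPP helpers above, with a toy model standing in for a real network. The import path (timm.utils.model) is assumed, and the return value is expected to be the hook's stats dict:

import torch.nn as nn
from timm.utils.model import extract_spp_stats, avg_sq_ch_mean, avg_ch_var  # assumed path

# hook_fn_locs are fnmatch patterns matched against model.named_modules() names
net = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 32, 3))
stats = extract_spp_stats(
    net,
    hook_fn_locs=['*', '*'],               # one pattern per hook fn
    hook_fns=[avg_sq_ch_mean, avg_ch_var],
    input_shape=[2, 3, 32, 32],
)
# stats should map each hook fn name to a list with one entry per hooked module,
# e.g. stats['avg_sq_ch_mean'] and stats['avg_ch_var']
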
@@ -186,7 +193,7 @@ def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True,
    named_modules = submodules
    submodules = [root_module.get_submodule(m) for m in submodules]

    if not(len(submodules)):
    if not (len(submodules)):
        named_modules, submodules = list(zip(*root_module.named_children()))

    for n, m in zip(named_modules, submodules):
@@ -201,13 +208,14 @@ def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True,
                    module.get_submodule(split[0]).add_module(split[1], submodule)
                else:
                    module.add_module(name, submodule)

            # Freeze batch norm
            if mode == 'freeze':
                res = freeze_batch_norm_2d(m)
                # It's possible that `m` is a type of BatchNorm in itself, in which case `freeze_batch_norm_2d` won't
                # convert it in place, but will return the converted result. In this case `res` holds the converted
                # result and we may try to re-assign the named module
                if isinstance(m, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
                if isinstance(m, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
                    _add_submodule(root_module, n, res)
            # Unfreeze batch norm
            else: