Fix some formatting in utils/model.py

pull/933/head
Ross Wightman 3 years ago
parent 0fe4fd3f1f
commit 57992509f9

@@ -2,11 +2,9 @@
 Hacked together by / Copyright 2020 Ross Wightman
 """
-from logging import root
-from typing import Sequence
-import torch
 import fnmatch
+import torch
 
 from torchvision.ops.misc import FrozenBatchNorm2d
 
 from .model_ema import ModelEma
@@ -23,19 +21,22 @@ def get_state_dict(model, unwrap_fn=unwrap_model):
     return unwrap_fn(model).state_dict()
 
 
 def avg_sq_ch_mean(model, input, output):
-    "calculate average channel square mean of output activations"
-    return torch.mean(output.mean(axis=[0,2,3])**2).item()
+    """ calculate average channel square mean of output activations
+    """
+    return torch.mean(output.mean(axis=[0, 2, 3]) ** 2).item()
 
 
 def avg_ch_var(model, input, output):
-    "calculate average channel variance of output activations"
-    return torch.mean(output.var(axis=[0,2,3])).item()\
+    """ calculate average channel variance of output activations
+    """
+    return torch.mean(output.var(axis=[0, 2, 3])).item()
 
 
 def avg_ch_var_residual(model, input, output):
-    "calculate average channel variance of output activations"
-    return torch.mean(output.var(axis=[0,2,3])).item()
+    """ calculate average channel variance of output activations
+    """
+    return torch.mean(output.var(axis=[0, 2, 3])).item()
 
 
 class ActivationStatsHook:
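These three helpers share the `(module, input, output)` signature of a PyTorch forward hook, each reducing a feature map to a single scalar. A minimal sketch of what they compute, assuming `timm.utils.model` exposes them under these names:

```python
import torch
from timm.utils.model import avg_sq_ch_mean, avg_ch_var

# Dummy activation: batch 8, 64 channels, 56x56 spatial
out = torch.randn(8, 64, 56, 56)

# Reducing over batch and spatial dims leaves one value per channel;
# the helpers then collapse those 64 values to a single scalar.
print(avg_sq_ch_mean(None, None, out))  # average squared channel mean, ~0 here
print(avg_ch_var(None, None, out))      # average channel variance, ~1.0 here
```

The stray trailing backslash removed from the old `avg_ch_var` was a real bug, not just formatting: it continued the line into the blank line below.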
@@ -64,15 +65,16 @@ class ActivationStatsHook:
             raise ValueError("Please provide `hook_fns` for each `hook_fn_locs`, \
                 their lengths are different.")
         self.stats = dict((hook_fn.__name__, []) for hook_fn in hook_fns)
         for hook_fn_loc, hook_fn in zip(hook_fn_locs, hook_fns):
             self.register_hook(hook_fn_loc, hook_fn)
+
     def _create_hook(self, hook_fn):
         def append_activation_stats(module, input, output):
             out = hook_fn(module, input, output)
             self.stats[hook_fn.__name__].append(out)
         return append_activation_stats
 
     def register_hook(self, hook_fn_loc, hook_fn):
         for name, module in self.model.named_modules():
             if not fnmatch.fnmatch(name, hook_fn_loc):
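`register_hook` matches module names against shell-style wildcards via `fnmatch`, so one pattern can target a whole family of layers. A small illustration of the matching itself (stdlib only; the module names are hypothetical ResNet-style ones):

```python
import fnmatch

names = ['conv1', 'layer1.0.conv1', 'layer1.0.bn2', 'layer2.1.bn2']
pattern = 'layer*.*.bn2'  # every block's second batch norm

print([n for n in names if fnmatch.fnmatch(n, pattern)])
# ['layer1.0.bn2', 'layer2.1.bn2']
```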
@@ -80,17 +82,18 @@ class ActivationStatsHook:
             module.register_forward_hook(self._create_hook(hook_fn))
 
 
-def extract_spp_stats(model,
-                      hook_fn_locs,
-                      hook_fns,
-                      input_shape=[8, 3, 224, 224]):
+def extract_spp_stats(
+        model,
+        hook_fn_locs,
+        hook_fns,
+        input_shape=[8, 3, 224, 224]):
     """Extract average square channel mean and variance of activations during
     forward pass to plot Signal Propogation Plots (SPP).
 
     Paper: https://arxiv.org/abs/2101.08692
     Example Usage: https://gist.github.com/amaarora/6e56942fcb46e67ba203f3009b30d950
 
     """
     x = torch.normal(0., 1., input_shape)
     hook = ActivationStatsHook(model, hook_fn_locs=hook_fn_locs, hook_fns=hook_fns)
     _ = model(x)
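Taken together, `extract_spp_stats` pushes one random batch through the model and records a scalar per matched layer per hook fn. A hedged usage sketch, assuming the function returns the hook's `stats` dict (the `return` falls outside this hunk) and that the `layer?.?.bn2` pattern fits the chosen model's module names:

```python
import timm
from timm.utils.model import extract_spp_stats, avg_sq_ch_mean, avg_ch_var

model = timm.create_model('resnet18')
stats = extract_spp_stats(
    model,
    hook_fn_locs=['layer?.?.bn2', 'layer?.?.bn2'],  # one pattern per hook fn
    hook_fns=[avg_sq_ch_mean, avg_ch_var],
    input_shape=[2, 3, 224, 224])

# One entry per matched module, keyed by hook fn name
print(len(stats['avg_sq_ch_mean']), stats['avg_ch_var'][:3])
```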
@@ -188,7 +191,7 @@ def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True,
         named_modules = submodules
         submodules = [root_module.get_submodule(m) for m in submodules]
 
-    if not(len(submodules)):
+    if not len(submodules):
         named_modules, submodules = list(zip(*root_module.named_children()))
 
     for n, m in zip(named_modules, submodules):
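The cleaned-up `if not len(submodules):` branch falls back to the root's direct children when no submodule names are passed. A minimal sketch of that unpacking idiom:

```python
import torch.nn as nn

root = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
named_modules, submodules = list(zip(*root.named_children()))

print(named_modules)  # ('0', '1', '2'): names of the direct children
```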
@@ -203,13 +206,14 @@ def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True,
                     module.get_submodule(split[0]).add_module(split[1], submodule)
                 else:
                     module.add_module(name, submodule)
+
             # Freeze batch norm
             if mode == 'freeze':
                 res = freeze_batch_norm_2d(m)
                 # It's possible that `m` is a type of BatchNorm in itself, in which case `unfreeze_batch_norm_2d` won't
                 # convert it in place, but will return the converted result. In this case `res` holds the converted
                 # result and we may try to re-assign the named module
                 if isinstance(m, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
                     _add_submodule(root_module, n, res)
             # Unfreeze batch norm
             else:
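The comment block is the key subtlety in this hunk: when `m` is itself a batch norm layer, `freeze_batch_norm_2d` cannot swap it in place inside a parent, so the converted module comes back as the return value and must be re-attached by name. A sketch of that behaviour, assuming `freeze_batch_norm_2d` is importable from `timm.utils.model` as added in this PR:

```python
import torch
from torchvision.ops.misc import FrozenBatchNorm2d
from timm.utils.model import freeze_batch_norm_2d

bn = torch.nn.BatchNorm2d(16)
res = freeze_batch_norm_2d(bn)

# The original module object is untouched; only the returned copy is frozen,
# which is why the caller above re-assigns `res` into the parent module.
print(type(bn).__name__, type(res).__name__)  # BatchNorm2d FrozenBatchNorm2d
```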
@@ -267,4 +271,3 @@ def unfreeze(root_module, submodules=[], include_bn_running_stats=True):
     See example in docstring for `freeze`.
     """
     _freeze_unfreeze(root_module, submodules, include_bn_running_stats=include_bn_running_stats, mode="unfreeze")
-
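For the public entry points, a hedged end-to-end sketch, assuming `freeze`/`unfreeze` are importable from `timm.utils.model` with the signatures shown in the diff:

```python
import timm
from timm.utils.model import freeze, unfreeze

model = timm.create_model('resnet18')

freeze(model, ['layer1'])  # freeze layer1's params and its BN running stats
assert not next(model.layer1.parameters()).requires_grad

unfreeze(model, ['layer1'])  # restore training behaviour
assert next(model.layer1.parameters()).requires_grad
```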