update to work with fnmatch

pull/525/head
Aman Arora 4 years ago
parent 20626e8387
commit b85be24054

@ -4,7 +4,7 @@ Hacked together by / Copyright 2020 Ross Wightman
""" """
from .model_ema import ModelEma from .model_ema import ModelEma
import torch import torch
import fnmatch
def unwrap_model(model): def unwrap_model(model):
if isinstance(model, ModelEma): if isinstance(model, ModelEma):
@ -23,32 +23,37 @@ def avg_sq_ch_mean(model, input, output):
def avg_ch_var(model, input, output): def avg_ch_var(model, input, output):
"calculate average channel variance of output activations"
return torch.mean(output.var(axis=[0,2,3])).item()\
def avg_ch_var_residual(model, input, output):
"calculate average channel variance of output activations" "calculate average channel variance of output activations"
return torch.mean(output.var(axis=[0,2,3])).item() return torch.mean(output.var(axis=[0,2,3])).item()
class ActivationStatsHook: class ActivationStatsHook:
"""Iterates through each of `model`'s modules and if module's class name """Iterates through each of `model`'s modules and matches modules using unix pattern
is present in `layer_names` then registers `hook_fns` inside that module matching based on `layer_name` and `layer_type`. If there is match, this class adds
and stores activation stats inside `self.stats`. creates a hook using `hook_fn` and adds it to the module.
Arguments: Arguments:
model (nn.Module): model from which we will extract the activation stats model (nn.Module): model from which we will extract the activation stats
layer_names (List[str]): The layer name to look for to register forward layer_names (str): The layer name to look for to register forward
hook. Example, `BasicBlock`, `Bottleneck` hook. Example, 'stem', 'stages'
hook_fns (List[Callable]): List of hook functions to be registered at every hook_fns (List[Callable]): List of hook functions to be registered at every
module in `layer_names`. module in `layer_names`.
Inspiration from https://docs.fast.ai/callback.hook.html. Inspiration from https://docs.fast.ai/callback.hook.html.
""" """
def __init__(self, model, layer_names, hook_fns=[avg_sq_ch_mean, avg_ch_var]): def __init__(self, model, hook_fn_locs, hook_fns):
self.model = model self.model = model
self.layer_names = layer_names self.hook_fn_locs = hook_fn_locs
self.hook_fns = hook_fns self.hook_fns = hook_fns
self.stats = dict((hook_fn.__name__, []) for hook_fn in hook_fns) self.stats = dict((hook_fn.__name__, []) for hook_fn in hook_fns)
for hook_fn in hook_fns: for hook_fn_loc, hook_fn in zip(hook_fn_locs, hook_fns):
self.register_hook(layer_names, hook_fn) self.register_hook(hook_fn_loc, hook_fn)
def _create_hook(self, hook_fn): def _create_hook(self, hook_fn):
def append_activation_stats(module, input, output): def append_activation_stats(module, input, output):
@ -56,17 +61,16 @@ class ActivationStatsHook:
self.stats[hook_fn.__name__].append(out) self.stats[hook_fn.__name__].append(out)
return append_activation_stats return append_activation_stats
def register_hook(self, layer_names, hook_fn): def register_hook(self, hook_fn_loc, hook_fn):
for layer in self.model.modules(): for name, module in self.model.named_modules():
layer_name = layer.__class__.__name__ if not fnmatch.fnmatch(name, hook_fn_loc):
if layer_name not in layer_names:
continue continue
layer.register_forward_hook(self._create_hook(hook_fn)) module.register_forward_hook(self._create_hook(hook_fn))
def extract_spp_stats(model, def extract_spp_stats(model,
layer_names, hook_fn_locs,
hook_fns=[avg_sq_ch_mean, avg_ch_var], hook_fns,
input_shape=[8, 3, 224, 224]): input_shape=[8, 3, 224, 224]):
"""Extract average square channel mean and variance of activations during """Extract average square channel mean and variance of activations during
forward pass to plot Signal Propogation Plots (SPP). forward pass to plot Signal Propogation Plots (SPP).
@ -74,7 +78,7 @@ def extract_spp_stats(model,
Paper: https://arxiv.org/abs/2101.08692 Paper: https://arxiv.org/abs/2101.08692
""" """
x = torch.normal(0., 1., input_shape) x = torch.normal(0., 1., input_shape)
hook = ActivationStatsHook(model, layer_names, hook_fns) hook = ActivationStatsHook(model, hook_fn_locs=hook_fn_locs, hook_fns=hook_fns)
_ = model(x) _ = model(x)
return hook.stats return hook.stats
Loading…
Cancel
Save