Fix some formatting in utils/model.py

pull/933/head
Ross Wightman 3 years ago
parent 0fe4fd3f1f
commit 57992509f9

@ -2,11 +2,9 @@
Hacked together by / Copyright 2020 Ross Wightman Hacked together by / Copyright 2020 Ross Wightman
""" """
from logging import root import fnmatch
from typing import Sequence
import torch import torch
import fnmatch
from torchvision.ops.misc import FrozenBatchNorm2d from torchvision.ops.misc import FrozenBatchNorm2d
from .model_ema import ModelEma from .model_ema import ModelEma
@ -24,18 +22,21 @@ def get_state_dict(model, unwrap_fn=unwrap_model):
def avg_sq_ch_mean(model, input, output): def avg_sq_ch_mean(model, input, output):
"calculate average channel square mean of output activations" """ calculate average channel square mean of output activations
return torch.mean(output.mean(axis=[0,2,3])**2).item() """
return torch.mean(output.mean(axis=[0, 2, 3]) ** 2).item()
def avg_ch_var(model, input, output): def avg_ch_var(model, input, output):
"calculate average channel variance of output activations" """ calculate average channel variance of output activations
return torch.mean(output.var(axis=[0,2,3])).item()\ """
return torch.mean(output.var(axis=[0, 2, 3])).item()
def avg_ch_var_residual(model, input, output): def avg_ch_var_residual(model, input, output):
"calculate average channel variance of output activations" """ calculate average channel variance of output activations
return torch.mean(output.var(axis=[0,2,3])).item() """
return torch.mean(output.var(axis=[0, 2, 3])).item()
class ActivationStatsHook: class ActivationStatsHook:
@ -71,6 +72,7 @@ class ActivationStatsHook:
def append_activation_stats(module, input, output): def append_activation_stats(module, input, output):
out = hook_fn(module, input, output) out = hook_fn(module, input, output)
self.stats[hook_fn.__name__].append(out) self.stats[hook_fn.__name__].append(out)
return append_activation_stats return append_activation_stats
def register_hook(self, hook_fn_loc, hook_fn): def register_hook(self, hook_fn_loc, hook_fn):
@ -80,7 +82,8 @@ class ActivationStatsHook:
module.register_forward_hook(self._create_hook(hook_fn)) module.register_forward_hook(self._create_hook(hook_fn))
def extract_spp_stats(model, def extract_spp_stats(
model,
hook_fn_locs, hook_fn_locs,
hook_fns, hook_fns,
input_shape=[8, 3, 224, 224]): input_shape=[8, 3, 224, 224]):
@ -188,7 +191,7 @@ def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True,
named_modules = submodules named_modules = submodules
submodules = [root_module.get_submodule(m) for m in submodules] submodules = [root_module.get_submodule(m) for m in submodules]
if not(len(submodules)): if not len(submodules):
named_modules, submodules = list(zip(*root_module.named_children())) named_modules, submodules = list(zip(*root_module.named_children()))
for n, m in zip(named_modules, submodules): for n, m in zip(named_modules, submodules):
@ -203,6 +206,7 @@ def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True,
module.get_submodule(split[0]).add_module(split[1], submodule) module.get_submodule(split[0]).add_module(split[1], submodule)
else: else:
module.add_module(name, submodule) module.add_module(name, submodule)
# Freeze batch norm # Freeze batch norm
if mode == 'freeze': if mode == 'freeze':
res = freeze_batch_norm_2d(m) res = freeze_batch_norm_2d(m)
@ -267,4 +271,3 @@ def unfreeze(root_module, submodules=[], include_bn_running_stats=True):
See example in docstring for `freeze`. See example in docstring for `freeze`.
""" """
_freeze_unfreeze(root_module, submodules, include_bn_running_stats=include_bn_running_stats, mode="unfreeze") _freeze_unfreeze(root_module, submodules, include_bn_running_stats=include_bn_running_stats, mode="unfreeze")
Loading…
Cancel
Save