Small post-merge tweak for freeze/unfreeze, add to __init__ for utils

pull/910/head
Ross Wightman 3 years ago
parent 5ca72dcc75
commit e5da481073

@ -7,7 +7,7 @@ from .jit import set_jit_legacy
from .log import setup_default_logging, FormatterNoInfo from .log import setup_default_logging, FormatterNoInfo
from .metrics import AverageMeter, accuracy from .metrics import AverageMeter, accuracy
from .misc import natural_key, add_bool_arg from .misc import natural_key, add_bool_arg
from .model import unwrap_model, get_state_dict from .model import unwrap_model, get_state_dict, freeze, unfreeze
from .model_ema import ModelEma, ModelEmaV2 from .model_ema import ModelEma, ModelEmaV2
from .random import random_seed from .random import random_seed
from .summary import update_summary, get_outdir from .summary import update_summary, get_outdir

@ -194,7 +194,7 @@ def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True,
for n, m in zip(named_modules, submodules): for n, m in zip(named_modules, submodules):
# (Un)freeze parameters # (Un)freeze parameters
for p in m.parameters(): for p in m.parameters():
p.requires_grad = (False if mode == 'freeze' else True) p.requires_grad = False if mode == 'freeze' else True
if include_bn_running_stats: if include_bn_running_stats:
# Helper to add submodule specified as a named_module # Helper to add submodule specified as a named_module
def _add_submodule(module, name, submodule): def _add_submodule(module, name, submodule):

Loading…
Cancel
Save