From e5da481073ac4beb634ee9b33e264baa3bee8688 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 6 Oct 2021 17:00:27 -0700 Subject: [PATCH] Small post-merge tweak for freeze/unfreeze, add to __init__ for utils --- timm/utils/__init__.py | 2 +- timm/utils/model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/utils/__init__.py b/timm/utils/__init__.py index d02e62d2..11de9c9c 100644 --- a/timm/utils/__init__.py +++ b/timm/utils/__init__.py @@ -7,7 +7,7 @@ from .jit import set_jit_legacy from .log import setup_default_logging, FormatterNoInfo from .metrics import AverageMeter, accuracy from .misc import natural_key, add_bool_arg -from .model import unwrap_model, get_state_dict +from .model import unwrap_model, get_state_dict, freeze, unfreeze from .model_ema import ModelEma, ModelEmaV2 from .random import random_seed from .summary import update_summary, get_outdir diff --git a/timm/utils/model.py b/timm/utils/model.py index ffe66049..879ac3f8 100644 --- a/timm/utils/model.py +++ b/timm/utils/model.py @@ -194,7 +194,7 @@ def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True, for n, m in zip(named_modules, submodules): # (Un)freeze parameters for p in m.parameters(): - p.requires_grad = (False if mode == 'freeze' else True) + p.requires_grad = False if mode == 'freeze' else True if include_bn_running_stats: # Helper to add submodule specified as a named_module def _add_submodule(module, name, submodule):