Remove layer-decay print

pull/1467/head
Ross Wightman 2 years ago
parent e069249a2d
commit 33e30f8c8b

@ -1,7 +1,7 @@
""" Optimizer Factory w/ Custom Weight Decay """ Optimizer Factory w/ Custom Weight Decay
Hacked together by / Copyright 2021 Ross Wightman Hacked together by / Copyright 2021 Ross Wightman
""" """
import json import logging
from itertools import islice from itertools import islice
from typing import Optional, Callable, Tuple from typing import Optional, Callable, Tuple
@ -31,6 +31,8 @@ try:
except ImportError: except ImportError:
has_apex = False has_apex = False
_logger = logging.getLogger(__name__)
def param_groups_weight_decay( def param_groups_weight_decay(
model: nn.Module, model: nn.Module,
@ -92,6 +94,7 @@ def param_groups_layer_decay(
no_weight_decay_list: Tuple[str] = (), no_weight_decay_list: Tuple[str] = (),
layer_decay: float = .75, layer_decay: float = .75,
end_layer_decay: Optional[float] = None, end_layer_decay: Optional[float] = None,
verbose: bool = False,
): ):
""" """
Parameter groups for layer-wise lr decay & weight decay Parameter groups for layer-wise lr decay & weight decay
@ -142,8 +145,9 @@ def param_groups_layer_decay(
param_group_names[group_name]["param_names"].append(name) param_group_names[group_name]["param_names"].append(name)
param_groups[group_name]["params"].append(param) param_groups[group_name]["params"].append(param)
# FIXME temporary output to debug new feature if verbose:
print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) import json
_logger.info("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
return list(param_groups.values()) return list(param_groups.values())

Loading…
Cancel
Save