|
|
@ -1,7 +1,7 @@
|
|
|
|
""" Optimizer Factory w/ Custom Weight Decay
|
|
|
|
""" Optimizer Factory w/ Custom Weight Decay
|
|
|
|
Hacked together by / Copyright 2021 Ross Wightman
|
|
|
|
Hacked together by / Copyright 2021 Ross Wightman
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
import json
|
|
|
|
import logging
|
|
|
|
from itertools import islice
|
|
|
|
from itertools import islice
|
|
|
|
from typing import Optional, Callable, Tuple
|
|
|
|
from typing import Optional, Callable, Tuple
|
|
|
|
|
|
|
|
|
|
|
@ -31,6 +31,8 @@ try:
|
|
|
|
except ImportError:
|
|
|
|
except ImportError:
|
|
|
|
has_apex = False
|
|
|
|
has_apex = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def param_groups_weight_decay(
|
|
|
|
def param_groups_weight_decay(
|
|
|
|
model: nn.Module,
|
|
|
|
model: nn.Module,
|
|
|
@ -92,6 +94,7 @@ def param_groups_layer_decay(
|
|
|
|
no_weight_decay_list: Tuple[str] = (),
|
|
|
|
no_weight_decay_list: Tuple[str] = (),
|
|
|
|
layer_decay: float = .75,
|
|
|
|
layer_decay: float = .75,
|
|
|
|
end_layer_decay: Optional[float] = None,
|
|
|
|
end_layer_decay: Optional[float] = None,
|
|
|
|
|
|
|
|
verbose: bool = False,
|
|
|
|
):
|
|
|
|
):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Parameter groups for layer-wise lr decay & weight decay
|
|
|
|
Parameter groups for layer-wise lr decay & weight decay
|
|
|
@ -142,8 +145,9 @@ def param_groups_layer_decay(
|
|
|
|
param_group_names[group_name]["param_names"].append(name)
|
|
|
|
param_group_names[group_name]["param_names"].append(name)
|
|
|
|
param_groups[group_name]["params"].append(param)
|
|
|
|
param_groups[group_name]["params"].append(param)
|
|
|
|
|
|
|
|
|
|
|
|
# FIXME temporary output to debug new feature
|
|
|
|
if verbose:
|
|
|
|
print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
|
|
|
|
import json
|
|
|
|
|
|
|
|
_logger.info("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
|
|
|
|
|
|
|
|
|
|
|
|
return list(param_groups.values())
|
|
|
|
return list(param_groups.values())
|
|
|
|
|
|
|
|
|
|
|
|