Fix bug introduced in non layer_decay weight_decay application. Remove debug print, fix arg desc.

pull/1239/head
Ross Wightman 3 years ago
parent 372ad5fa0d
commit 0557c8257d

@ -660,7 +660,6 @@ def group_with_matcher(
for k in sorted(filter(lambda x: x is not None, grouping.keys())): for k in sorted(filter(lambda x: x is not None, grouping.keys())):
if lid < 0 or k[-1] != MATCH_PREV_GROUP[0]: if lid < 0 or k[-1] != MATCH_PREV_GROUP[0]:
lid += 1 lid += 1
print(lid, k, grouping[k])
layer_id_to_param[lid].extend(grouping[k]) layer_id_to_param[lid].extend(grouping[k])
if reverse: if reverse:

@ -44,7 +44,7 @@ def param_groups_weight_decay(
if not param.requires_grad: if not param.requires_grad:
continue continue
if param.ndim or name.endswith(".bias") or name in no_weight_decay_list: if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list:
no_decay.append(param) no_decay.append(param)
else: else:
decay.append(param) decay.append(param)

@ -140,7 +140,7 @@ parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
parser.add_argument('--clip-mode', type=str, default='norm', parser.add_argument('--clip-mode', type=str, default='norm',
help='Gradient clipping mode. One of ("norm", "value", "agc")') help='Gradient clipping mode. One of ("norm", "value", "agc")')
parser.add_argument('--layer-decay', type=float, default=None, parser.add_argument('--layer-decay', type=float, default=None,
help='weight decay (default: None)') help='layer-wise learning rate decay (default: None)')
# Learning rate schedule parameters # Learning rate schedule parameters
parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',

Loading…
Cancel
Save