Add support to clean_checkpoint.py to remove aux_bn weights/biases from SplitBatchNorm

pull/82/head
Ross Wightman 4 years ago
parent 2a88412413
commit cc0b1f4130

@ -21,7 +21,8 @@ parser.add_argument('--output', default='', type=str, metavar='PATH',
help='output path')
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
help='use ema version of weights if present')
parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true',
help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint')
_TEMP_NAME = './_checkpoint.pth'
@ -48,6 +49,10 @@ def main():
else:
assert False
for k, v in state_dict.items():
if args.clean_aux_bn and 'aux_bn' in k:
# If all aux_bn keys are removed, the SplitBN layers will end up as normal and
# load with the unmodified model using BatchNorm2d.
continue
name = k[7:] if k.startswith('module') else k
new_state_dict[name] = v
print("=> Loaded state_dict from '{}'".format(args.checkpoint))

Loading…
Cancel
Save