@@ -21,7 +21,8 @@ parser.add_argument('--output', default='', type=str, metavar='PATH',
                     help='output path')
 parser.add_argument('--use-ema', dest='use_ema', action='store_true',
                     help='use ema version of weights if present')
+parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true',
+                    help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint')
 
 _TEMP_NAME = './_checkpoint.pth'
 
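For reference, the new flag is used together with the script's existing --checkpoint and --output options. Assuming the file being patched is timm's clean_checkpoint.py (the filename is not shown in these hunks) and using illustrative paths, an invocation that also strips the auxiliary BN entries might look like:

python clean_checkpoint.py --checkpoint ./model_best.pth.tar --output ./cleaned.pth --clean-aux-bn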
@@ -48,6 +49,10 @@ def main():
         else:
             assert False
         for k, v in state_dict.items():
+            if args.clean_aux_bn and 'aux_bn' in k:
+                # If all aux_bn keys are removed, the SplitBN layers will end up as normal and
+                # load with the unmodified model using BatchNorm2d.
+                continue
             name = k[7:] if k.startswith('module') else k
             new_state_dict[name] = v
         print("=> Loaded state_dict from '{}'".format(args.checkpoint))
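To make the effect of the loop in the second hunk concrete, below is a minimal standalone sketch of the same key filtering. It assumes SplitBN-style checkpoints keep the auxiliary statistics under an 'aux_bn' submodule (so the substring 'aux_bn' appears in their state-dict keys); the key names and values here are hypothetical.

from collections import OrderedDict

# Hypothetical keys from a DataParallel + SplitBN checkpoint (values are dummies).
state_dict = OrderedDict([
    ('module.conv1.weight', 0),
    ('module.bn1.weight', 1),                 # primary BN parameter -> kept
    ('module.bn1.aux_bn.0.running_mean', 2),  # auxiliary split-BN buffer -> dropped
    ('module.bn1.aux_bn.0.running_var', 3),   # auxiliary split-BN buffer -> dropped
])

clean_aux_bn = True  # stands in for args.clean_aux_bn
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    if clean_aux_bn and 'aux_bn' in k:
        continue  # drop aux BN entries so plain BatchNorm2d modules can load the result
    name = k[7:] if k.startswith('module') else k  # strip the DataParallel 'module.' prefix
    new_state_dict[name] = v

print(list(new_state_dict))  # ['conv1.weight', 'bn1.weight']

Because the check keys on the 'aux_bn' substring rather than on module types, it works directly on the raw state dict without instantiating the model.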