@@ -13,6 +13,7 @@ import os
 import hashlib
 import shutil
 from collections import OrderedDict
+from timm.models.helpers import load_state_dict
 
 parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
 parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
@@ -37,17 +38,8 @@ def main():
     # Load an existing checkpoint to CPU, strip everything but the state_dict and re-save
     if args.checkpoint and os.path.isfile(args.checkpoint):
         print("=> Loading checkpoint '{}'".format(args.checkpoint))
-        checkpoint = torch.load(args.checkpoint, map_location='cpu')
-
-        new_state_dict = OrderedDict()
-        if isinstance(checkpoint, dict):
-            state_dict_key = 'state_dict_ema' if args.use_ema else 'state_dict'
-            if state_dict_key in checkpoint:
-                state_dict = checkpoint[state_dict_key]
-            else:
-                state_dict = checkpoint
-        else:
-            assert False
+        state_dict = load_state_dict(args.checkpoint, use_ema=args.use_ema)
+        new_state_dict = {}
         for k, v in state_dict.items():
             if args.clean_aux_bn and 'aux_bn' in k:
                 # If all aux_bn keys are removed, the SplitBN layers will end up as normal and
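
For context, a minimal sketch (not part of the patch) of how the refactored flow fits together, using the load_state_dict helper the patch imports. The paths here are hypothetical, and the final save is illustrative only; the hashlib/shutil imports in the first hunk suggest the real script hashes and renames its output rather than saving like this.

import torch
from timm.models.helpers import load_state_dict

# Hypothetical paths for illustration.
src, dst = './checkpoint.pth.tar', './cleaned.pth'

# load_state_dict unwraps the raw checkpoint (a bare state_dict, or a dict
# holding 'state_dict' / 'state_dict_ema') in place of the removed
# torch.load + isinstance branching.
state_dict = load_state_dict(src, use_ema=True)

new_state_dict = {}
for k, v in state_dict.items():
    if 'aux_bn' in k:
        # Drop auxiliary SplitBN stats so the layers load as plain BatchNorm2d,
        # mirroring the --clean-aux-bn branch in the patch.
        continue
    new_state_dict[k] = v

# Illustrative save; see note above about the script's actual output handling.
torch.save(new_state_dict, dst)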