From 8a33a6c90a2b74c50f0129926f1830e5469e5f74 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 13 Apr 2019 14:15:35 -0700 Subject: [PATCH] Add checkpoint clean script, add link to pretrained resnext50 weights --- clean_checkpoint.py | 45 +++++++++++++++++++++++++++++++++++++++++++++ models/resnet.py | 6 ++++-- models/senet.py | 2 +- 3 files changed, 50 insertions(+), 3 deletions(-) create mode 100644 clean_checkpoint.py diff --git a/clean_checkpoint.py b/clean_checkpoint.py new file mode 100644 index 00000000..471630b6 --- /dev/null +++ b/clean_checkpoint.py @@ -0,0 +1,45 @@ +import torch +import argparse +import os +import hashlib +from collections import OrderedDict + +parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation') +parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', + help='path to latest checkpoint (default: none)') +parser.add_argument('--output', default='./cleaned.pth', type=str, metavar='PATH', + help='output path') + + +def main(): + args = parser.parse_args() + + if os.path.exists(args.output): + print("Error: Output filename ({}) already exists.".format(args.output)) + exit(1) + + # Load an existing checkpoint to CPU, strip everything but the state_dict and re-save + if args.checkpoint and os.path.isfile(args.checkpoint): + print("=> Loading checkpoint '{}'".format(args.checkpoint)) + checkpoint = torch.load(args.checkpoint, map_location='cpu') + + new_state_dict = OrderedDict() + if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + for k, v in state_dict.items(): + name = k[7:] if k.startswith('module') else k + new_state_dict[name] = v + print("=> Loaded state_dict from '{}'".format(args.checkpoint)) + + torch.save(new_state_dict, args.output) + with open(args.output, 'rb') as f: + sha_hash = hashlib.sha256(f.read()).hexdigest() + print("=> Saved state_dict to '{}, SHA256: {}'".format(args.output, sha_hash)) + else: + print("Error: Checkpoint ({}) doesn't exist".format(args.checkpoint)) + + +if __name__ == '__main__': + main() diff --git a/models/resnet.py b/models/resnet.py index 7492d75e..514564ea 100644 --- a/models/resnet.py +++ b/models/resnet.py @@ -16,13 +16,14 @@ __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152 'resnext50_32x4d', 'resnext101_32x4d', 'resnext101_64x4d', 'resnext152_32x4d'] -def _cfg(url=''): +def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.875, 'interpolation': 'bilinear', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'first_conv': 'conv1', 'classifier': 'fc', + **kwargs } @@ -32,7 +33,8 @@ default_cfgs = { 'resnet50': _cfg(url='https://download.pytorch.org/models/resnet50-19c8e357.pth'), 'resnet101': _cfg(url='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'), 'resnet152': _cfg(url='https://download.pytorch.org/models/resnet152-b121ed2d.pth'), - 'resnext50_32x4d': _cfg(url=''), + 'resnext50_32x4d': _cfg(url='https://www.dropbox.com/s/yxci33lfew51p6a/resnext50_32x4d-068914d1.pth?dl=1', + interpolation='bicubic'), 'resnext101_32x4d': _cfg(url=''), 'resnext101_64x4d': _cfg(url=''), 'resnext152_32x4d': _cfg(url=''), diff --git a/models/senet.py b/models/senet.py index 75ffc398..c8be0769 100644 --- a/models/senet.py +++ b/models/senet.py @@ -23,7 +23,7 @@ __all__ = ['SENet', 'senet154', 'seresnet50', 'seresnet101', 'seresnet152', 'seresnext50_32x4d', 'seresnext101_32x4d'] -def _cfg(url=''): +def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.875, 'interpolation': 'bilinear',