Add checkpoint averaging script. Add headers, shebangs, exec perms to all scripts

pull/82/head
Ross Wightman 5 years ago
parent 4666cc9aed
commit 40fea63ebe

avg_checkpoints.py
@ -0,0 +1,113 @@
#!/usr/bin/env python
""" Checkpoint Averaging Script
This script averages all model weights for checkpoints in the specified path that match
the specified filter wildcard. All checkpoints must be from the exact same model.
For any hope of decent results, the checkpoints should be from the same or child
(via resumes) training session. This can be viewed as similar to maintaining running
EMA (exponential moving average) of the model weights or performing SWA (stochastic
weight averaging), but post-training.
Hacked together by Ross Wightman (https://github.com/rwightman)
"""
import torch
import argparse
import os
import glob
import hashlib
from timm.models.helpers import load_state_dict
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
parser.add_argument('--input', default='', type=str, metavar='PATH',
                    help='path to base input folder containing checkpoints')
parser.add_argument('--filter', default='*.pth.tar', type=str, metavar='WILDCARD',
                    help='checkpoint filter (path wildcard)')
parser.add_argument('--output', default='./averaged.pth', type=str, metavar='PATH',
                    help='output filename')
parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
                    help='Force not using EMA version of weights (if present)')
parser.add_argument('--no-sort', dest='no_sort', action='store_true',
                    help='Do not sort and select by checkpoint metric; also makes the "n" argument irrelevant')
parser.add_argument('-n', type=int, default=10, metavar='N',
                    help='Number of checkpoints to average')
def checkpoint_metric(checkpoint_path):
    if not checkpoint_path or not os.path.isfile(checkpoint_path):
        return None
    print("=> Extracting metric from checkpoint '{}'".format(checkpoint_path))
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    metric = None
    if 'metric' in checkpoint:
        metric = checkpoint['metric']
    return metric
def main():
    args = parser.parse_args()
    # by default use the EMA weights (if present)
    args.use_ema = not args.no_use_ema
    # by default sort by checkpoint metric (if present) and avg top n checkpoints
    args.sort = not args.no_sort

    if os.path.exists(args.output):
        print("Error: Output filename ({}) already exists.".format(args.output))
        exit(1)

    pattern = args.input
    if not args.input.endswith(os.path.sep) and not args.filter.startswith(os.path.sep):
        pattern += os.path.sep
    pattern += args.filter
    checkpoints = glob.glob(pattern, recursive=True)

    if args.sort:
        checkpoint_metrics = []
        for c in checkpoints:
            metric = checkpoint_metric(c)
            if metric is not None:
                checkpoint_metrics.append((metric, c))
        # sort ascending by metric, then keep the best (highest metric) n checkpoints
        checkpoint_metrics = sorted(checkpoint_metrics)[-args.n:]
        print("Selected checkpoints:")
        for m, c in checkpoint_metrics:
            print(m, c)
        avg_checkpoints = [c for m, c in checkpoint_metrics]
    else:
        avg_checkpoints = checkpoints
        print("Selected checkpoints:")
        for c in checkpoints:
            print(c)
    avg_state_dict = {}
    avg_counts = {}
    for c in avg_checkpoints:
        new_state_dict = load_state_dict(c, args.use_ema)
        if not new_state_dict:
            print("Error: Checkpoint ({}) doesn't exist".format(c))
            continue
        # accumulate sums in float64 to avoid precision loss across many checkpoints
        for k, v in new_state_dict.items():
            if k not in avg_state_dict:
                avg_state_dict[k] = v.clone().to(dtype=torch.float64)
                avg_counts[k] = 1
            else:
                avg_state_dict[k] += v.to(dtype=torch.float64)
                avg_counts[k] += 1

    # divide each accumulated sum by the number of checkpoints that contributed to it
    for k, v in avg_state_dict.items():
        v.div_(avg_counts[k])

    # float32 overflow seems unlikely based on weights seen to date, but who knows
    float32_info = torch.finfo(torch.float32)
    final_state_dict = {}
    for k, v in avg_state_dict.items():
        v = v.clamp(float32_info.min, float32_info.max)
        final_state_dict[k] = v.to(dtype=torch.float32)

    torch.save(final_state_dict, args.output)
    with open(args.output, 'rb') as f:
        sha_hash = hashlib.sha256(f.read()).hexdigest()
    print("=> Saved state_dict to '{}', SHA256: {}".format(args.output, sha_hash))


if __name__ == '__main__':
    main()
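The output is a plain state_dict, so it loads back like any cleaned checkpoint. A minimal usage sketch, assuming a hypothetical efficientnet_b0 run produced the checkpoints (the model name and default output path are illustrative, not from this commit):

import torch
from timm import create_model

model = create_model('efficientnet_b0', num_classes=1000)     # assumed model; must match the averaged checkpoints
state_dict = torch.load('./averaged.pth', map_location='cpu')  # the script's default --output
model.load_state_dict(state_dict)
model.eval()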

clean_checkpoint.py
@ -1,3 +1,12 @@
#!/usr/bin/env python
""" Checkpoint Cleaning Script
Takes training checkpoints with GPU tensors, optimizer state, extra dict keys, etc.
and outputs a CPU tensor checkpoint with only the `state_dict` along with SHA256
calculation for model zoo compatibility.
Hacked together by Ross Wightman (https://github.com/rwightman)
"""
import torch
import argparse
import os
@ -5,7 +14,7 @@ import hashlib
import shutil
from collections import OrderedDict

-parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
+parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--output', default='', type=str, metavar='PATH',
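The cleaning operation described in the new docstring boils down to a few lines. A rough sketch of the idea only, not the script itself; it assumes the training checkpoint keeps its weights under a 'state_dict' key:

import hashlib
import torch

ckpt = torch.load('checkpoint.pth.tar', map_location='cpu')   # map to CPU tensors
clean = {k: v for k, v in ckpt['state_dict'].items()}         # keep weights, drop optimizer state etc.
torch.save(clean, './cleaned.pth')
with open('./cleaned.pth', 'rb') as f:
    print(hashlib.sha256(f.read()).hexdigest())               # SHA256 for model zoo bookkeeping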

inference.py
@ -1,10 +1,10 @@
"""Sample PyTorch Inference script #!/usr/bin/env python
""" """PyTorch Inference Script
from __future__ import absolute_import An example inference script that outputs top-k class ids for images in a folder into a csv.
from __future__ import division
from __future__ import print_function
Hacked together by Ross Wightman (https://github.com/rwightman)
"""
import os import os
import time import time
import argparse import argparse
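For reference, the top-k step the new docstring mentions is a one-liner in PyTorch. A tiny sketch with fake logits (batch size, class count, and k are illustrative):

import torch

logits = torch.randn(4, 1000)    # pretend batch of 4 images, 1000 classes
topk = logits.topk(k=5, dim=1)   # largest-5 values and their class indices per image
print(topk.indices)              # the per-image class ids a script like this writes to csv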

timm/models/helpers.py
@ -5,12 +5,11 @@ import logging
from collections import OrderedDict

-def load_checkpoint(model, checkpoint_path, use_ema=False):
+def load_state_dict(checkpoint_path, use_ema=False):
    if checkpoint_path and os.path.isfile(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
-        state_dict_key = ''
-        if isinstance(checkpoint, dict):
-            state_dict_key = 'state_dict'
+        state_dict_key = 'state_dict'
+        if isinstance(checkpoint, dict):
            if use_ema and 'state_dict_ema' in checkpoint:
                state_dict_key = 'state_dict_ema'
        if state_dict_key and state_dict_key in checkpoint:
@ -19,15 +18,21 @@ def load_checkpoint(model, checkpoint_path, use_ema=False):
                # strip `module.` prefix
                name = k[7:] if k.startswith('module') else k
                new_state_dict[name] = v
-            model.load_state_dict(new_state_dict)
+            state_dict = new_state_dict
        else:
-            model.load_state_dict(checkpoint)
+            state_dict = checkpoint
-        logging.info("Loaded {} from checkpoint '{}'".format(state_dict_key or 'weights', checkpoint_path))
+        logging.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
+        return state_dict
    else:
        logging.error("No checkpoint found at '{}'".format(checkpoint_path))
        raise FileNotFoundError()

+def load_checkpoint(model, checkpoint_path, use_ema=False):
+    state_dict = load_state_dict(checkpoint_path, use_ema)
+    model.load_state_dict(state_dict)

def resume_checkpoint(model, checkpoint_path):
    other_state = {}
    resume_epoch = None
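The point of splitting load_state_dict out of load_checkpoint is that callers can now obtain a cleaned state_dict without instantiating a model, which is exactly what the new avg_checkpoints.py relies on. A usage sketch (the checkpoint path is illustrative):

from timm.models.helpers import load_state_dict

# prefer the EMA weights if the checkpoint carries them
state_dict = load_state_dict('output/train/checkpoint-42.pth.tar', use_ema=True)
print(sum(v.numel() for v in state_dict.values()), 'parameters loaded')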

train.py
@ -1,4 +1,19 @@
#!/usr/bin/env python
""" ImageNet Training Script
This is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet
training results with some of the latest networks and training techniques. It favours canonical PyTorch
and standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed
and training result improvements over the usual PyTorch example scripts. Repurpose as you see fit.
This script was started from an early version of the PyTorch ImageNet example
(https://github.com/pytorch/examples/tree/master/imagenet)
NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
(https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
Hacked together by Ross Wightman (https://github.com/rwightman)
"""
import argparse
import time
import logging
@ -35,7 +50,7 @@ parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
                    help='YAML config file specifying default arguments')

-parser = argparse.ArgumentParser(description='Training')
+parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

# Dataset / Model parameters
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')

validate.py
@ -1,7 +1,12 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+#!/usr/bin/env python
+""" ImageNet Validation Script
+This is intended to be a lean and easily modifiable ImageNet validation script for evaluating pretrained
+models or training checkpoints against ImageNet or similarly organized image datasets. It prioritizes
+canonical PyTorch, standard Python style, and good performance. Repurpose as you see fit.
+Hacked together by Ross Wightman (https://github.com/rwightman)
+"""
import argparse
import os
import csv
@ -182,6 +187,7 @@ def main():
        # validate all checkpoints in a path with same model
        checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')
        checkpoints += glob.glob(args.checkpoint + '/*.pth')
+        model_names = list_models(args.model)
        model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)]
    else:
        if args.model == 'all':
@ -195,7 +201,7 @@ def main():
            model_cfgs = [(n, '') for n in model_names]

    if len(model_cfgs):
-        print('Running bulk validation on these pretrained models:', ', '.join(model_names))
+        logging.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
        header_written = False
        with open('./results-all.csv', mode='w') as cf:
            for m, c in model_cfgs:
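Bulk validation appends one row per model/checkpoint to ./results-all.csv. Reading the results back is straightforward; a sketch (the column names depend on what validate.py writes, so none are assumed here):

import csv

with open('./results-all.csv') as f:
    for row in csv.DictReader(f):
        print(row)   # one dict per validated model/checkpoint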
