diff --git a/avg_checkpoint.py b/avg_checkpoint.py
new file mode 100755
index 00000000..99b0ab2f
--- /dev/null
+++ b/avg_checkpoint.py
@@ -0,0 +1,113 @@
+#!/usr/bin/env python
+""" Checkpoint Averaging Script
+
+This script averages the model weights of all checkpoints in the specified path that
+match the specified filter wildcard. All checkpoints must be from the exact same model.
+
+For any hope of decent results, the checkpoints should be from the same or a child
+(via resumes) training session. This can be viewed as similar to maintaining a running
+EMA (exponential moving average) of the model weights, or performing SWA (stochastic
+weight averaging), but post-training.
+
+Hacked together by Ross Wightman (https://github.com/rwightman)
+"""
+import torch
+import argparse
+import os
+import glob
+import hashlib
+from timm.models.helpers import load_state_dict
+
+parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
+parser.add_argument('--input', default='', type=str, metavar='PATH',
+                    help='path to base input folder containing checkpoints')
+parser.add_argument('--filter', default='*.pth.tar', type=str, metavar='WILDCARD',
+                    help='checkpoint filter (path wildcard)')
+parser.add_argument('--output', default='./averaged.pth', type=str, metavar='PATH',
+                    help='output filename')
+parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
+                    help='Force not using ema version of weights (if present)')
+parser.add_argument('--no-sort', dest='no_sort', action='store_true',
+                    help='Do not sort and select by checkpoint metric, also makes "n" argument irrelevant')
+parser.add_argument('-n', type=int, default=10, metavar='N',
+                    help='Number of checkpoints to average')
+
+
+def checkpoint_metric(checkpoint_path):
+    if not checkpoint_path or not os.path.isfile(checkpoint_path):
+        return {}
+    print("=> Extracting metric from checkpoint '{}'".format(checkpoint_path))
+    checkpoint = torch.load(checkpoint_path, map_location='cpu')
+    metric = None
+    if 'metric' in checkpoint:
+        metric = checkpoint['metric']
+    return metric
+
+
+def main():
+    args = parser.parse_args()
+    # by default use the EMA weights (if present)
+    args.use_ema = not args.no_use_ema
+    # by default sort by checkpoint metric (if present) and avg top n checkpoints
+    args.sort = not args.no_sort
+
+    if os.path.exists(args.output):
+        print("Error: Output filename ({}) already exists.".format(args.output))
+        exit(1)
+
+    pattern = args.input
+    if not args.input.endswith(os.path.sep) and not args.filter.startswith(os.path.sep):
+        pattern += os.path.sep
+    pattern += args.filter
+    checkpoints = glob.glob(pattern, recursive=True)
+
+    if args.sort:
+        checkpoint_metrics = []
+        for c in checkpoints:
+            metric = checkpoint_metric(c)
+            if metric is not None:
+                checkpoint_metrics.append((metric, c))
+        checkpoint_metrics = list(sorted(checkpoint_metrics))
+        checkpoint_metrics = checkpoint_metrics[-args.n:]
+        print("Selected checkpoints:")
+        [print(m, c) for m, c in checkpoint_metrics]
+        avg_checkpoints = [c for m, c in checkpoint_metrics]
+    else:
+        avg_checkpoints = checkpoints
+        print("Selected checkpoints:")
+        [print(c) for c in checkpoints]
+
+    avg_state_dict = {}
+    avg_counts = {}
+    for c in avg_checkpoints:
+        new_state_dict = load_state_dict(c, args.use_ema)
+        if not new_state_dict:
+            print("Error: Checkpoint ({}) doesn't exist".format(c))
+            continue
+
+        for k, v in new_state_dict.items():
+            if k not in avg_state_dict:
+                avg_state_dict[k] = v.clone().to(dtype=torch.float64)
+                avg_counts[k] = 1
+            else:
+                avg_state_dict[k] += v.to(dtype=torch.float64)
+                avg_counts[k] += 1
+
+    for k, v in avg_state_dict.items():
+        v.div_(avg_counts[k])
+
+    # float32 overflow seems unlikely based on weights seen to date, but who knows
+    float32_info = torch.finfo(torch.float32)
+    final_state_dict = {}
+    for k, v in avg_state_dict.items():
+        v = v.clamp(float32_info.min, float32_info.max)
+        final_state_dict[k] = v.to(dtype=torch.float32)
+
+    torch.save(final_state_dict, args.output)
+    with open(args.output, 'rb') as f:
+        sha_hash = hashlib.sha256(f.read()).hexdigest()
+    print("=> Saved state_dict to '{}', SHA256: {}".format(args.output, sha_hash))
+
+
+if __name__ == '__main__':
+    main()
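
For reference, the core of the averaging scheme above, condensed into a standalone helper. This is not part of the patch, and average_state_dicts is a hypothetical name: weights are accumulated in float64, divided by the contribution count, then clamped back into float32 range before downcasting, exactly as the script does.

import torch

def average_state_dicts(state_dicts):
    # Accumulate each weight tensor in float64 to limit precision loss, then
    # divide by the number of checkpoints that contributed to it.
    avg, counts = {}, {}
    for sd in state_dicts:
        for k, v in sd.items():
            if k not in avg:
                avg[k], counts[k] = v.clone().to(dtype=torch.float64), 1
            else:
                avg[k] += v.to(dtype=torch.float64)
                counts[k] += 1
    # Clamp into float32 range before downcasting, as the script does.
    f32 = torch.finfo(torch.float32)
    return {k: (v / counts[k]).clamp(f32.min, f32.max).to(dtype=torch.float32)
            for k, v in avg.items()}
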
diff --git a/clean_checkpoint.py b/clean_checkpoint.py
old mode 100644
new mode 100755
index b088aa8f..fef3104a
--- a/clean_checkpoint.py
+++ b/clean_checkpoint.py
@@ -1,3 +1,12 @@
+#!/usr/bin/env python
+""" Checkpoint Cleaning Script
+
+Takes training checkpoints with GPU tensors, optimizer state, extra dict keys, etc.
+and outputs a CPU tensor checkpoint with only the `state_dict` along with SHA256
+calculation for model zoo compatibility.
+
+Hacked together by Ross Wightman (https://github.com/rwightman)
+"""
 import torch
 import argparse
 import os
@@ -5,7 +14,7 @@ import hashlib
 import shutil
 from collections import OrderedDict
 
-parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
+parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
 parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                     help='path to latest checkpoint (default: none)')
 parser.add_argument('--output', default='', type=str, metavar='PATH',
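
For reference, a minimal sketch of the cleaning idea described above. This is not part of the patch, and the paths are hypothetical placeholders:

import hashlib
import torch

# Load a full training checkpoint onto the CPU, keep only the model weights,
# and record a SHA256 hash of the cleaned file for model zoo style naming.
checkpoint = torch.load('./checkpoint-1.pth.tar', map_location='cpu')
torch.save(checkpoint['state_dict'], './cleaned.pth')
with open('./cleaned.pth', 'rb') as f:
    print(hashlib.sha256(f.read()).hexdigest())
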
diff --git a/inference.py b/inference.py
old mode 100644
new mode 100755
index 3255a8d9..99370c51
--- a/inference.py
+++ b/inference.py
@@ -1,10 +1,10 @@
-"""Sample PyTorch Inference script
-"""
+#!/usr/bin/env python
+"""PyTorch Inference Script
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+An example inference script that outputs top-k class ids for images in a folder into a csv.
 
+Hacked together by Ross Wightman (https://github.com/rwightman)
+"""
 import os
 import time
 import argparse
diff --git a/timm/models/helpers.py b/timm/models/helpers.py
index 7460f4a2..84004db5 100644
--- a/timm/models/helpers.py
+++ b/timm/models/helpers.py
@@ -5,12 +5,11 @@ import logging
 from collections import OrderedDict
 
 
-def load_checkpoint(model, checkpoint_path, use_ema=False):
+def load_state_dict(checkpoint_path, use_ema=False):
     if checkpoint_path and os.path.isfile(checkpoint_path):
         checkpoint = torch.load(checkpoint_path, map_location='cpu')
-        state_dict_key = ''
+        state_dict_key = 'state_dict'
         if isinstance(checkpoint, dict):
-            state_dict_key = 'state_dict'
             if use_ema and 'state_dict_ema' in checkpoint:
                 state_dict_key = 'state_dict_ema'
         if state_dict_key and state_dict_key in checkpoint:
@@ -19,15 +18,21 @@ def load_checkpoint(model, checkpoint_path, use_ema=False):
                 # strip `module.` prefix
                 name = k[7:] if k.startswith('module') else k
                 new_state_dict[name] = v
-            model.load_state_dict(new_state_dict)
+            state_dict = new_state_dict
         else:
-            model.load_state_dict(checkpoint)
-        logging.info("Loaded {} from checkpoint '{}'".format(state_dict_key or 'weights', checkpoint_path))
+            state_dict = checkpoint
+        logging.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
+        return state_dict
     else:
         logging.error("No checkpoint found at '{}'".format(checkpoint_path))
         raise FileNotFoundError()
 
 
+def load_checkpoint(model, checkpoint_path, use_ema=False):
+    state_dict = load_state_dict(checkpoint_path, use_ema)
+    model.load_state_dict(state_dict)
+
+
 def resume_checkpoint(model, checkpoint_path):
     other_state = {}
     resume_epoch = None
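
For reference, a short usage sketch of the refactored helpers. This is not part of the patch; the model instance and checkpoint path are hypothetical placeholders:

from timm.models.helpers import load_checkpoint, load_state_dict

# Get the raw weights (EMA copy if present) without needing a model instance,
# e.g. for averaging or inspection:
state_dict = load_state_dict('./output/train/checkpoint-1.pth.tar', use_ema=True)

# Or load them straight into an instantiated model, as before the refactor:
# load_checkpoint(model, './output/train/checkpoint-1.pth.tar', use_ema=True)
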
diff --git a/train.py b/train.py
old mode 100644
new mode 100755
index b8f37f41..558c29ac
--- a/train.py
+++ b/train.py
@@ -1,4 +1,19 @@
+#!/usr/bin/env python
+""" ImageNet Training Script
+This is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet
+training results with some of the latest networks and training techniques. It favours canonical PyTorch
+and standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed
+and training result improvements over the usual PyTorch example scripts. Repurpose as you see fit.
+
+This script was started from an early version of the PyTorch ImageNet example
+(https://github.com/pytorch/examples/tree/master/imagenet)
+
+NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
+(https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
+
+Hacked together by Ross Wightman (https://github.com/rwightman)
+"""
 import argparse
 import time
 import logging
@@ -35,7 +50,7 @@ parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
                     help='YAML config file specifying default arguments')
 
-parser = argparse.ArgumentParser(description='Training')
+parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
 # Dataset / Model parameters
 parser.add_argument('data', metavar='DIR',
                     help='path to dataset')
diff --git a/validate.py b/validate.py
old mode 100644
new mode 100755
index 004393ab..93a82021
--- a/validate.py
+++ b/validate.py
@@ -1,7 +1,12 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+#!/usr/bin/env python
+""" ImageNet Validation Script
+This is intended to be a lean and easily modifiable ImageNet validation script for evaluating pretrained
+models or training checkpoints against ImageNet or similarly organized image datasets. It prioritizes
+canonical PyTorch, standard Python style, and good performance. Repurpose as you see fit.
+
+Hacked together by Ross Wightman (https://github.com/rwightman)
+"""
 import argparse
 import os
 import csv
@@ -182,6 +187,7 @@ def main():
         # validate all checkpoints in a path with same model
         checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')
         checkpoints += glob.glob(args.checkpoint + '/*.pth')
+        model_names = list_models(args.model)
         model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)]
     else:
         if args.model == 'all':
@@ -195,7 +201,7 @@
         model_cfgs = [(n, '') for n in model_names]
 
     if len(model_cfgs):
-        print('Running bulk validation on these pretrained models:', ', '.join(model_names))
+        logging.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
         header_written = False
         with open('./results-all.csv', mode='w') as cf:
             for m, c in model_cfgs:
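
For reference, an end-to-end sketch tying the new pieces together. This is not part of the patch; the model name and paths are hypothetical. After running avg_checkpoint.py over a training run's checkpoints, its default output is a plain float32 state_dict that loads directly:

import torch
from timm.models import create_model

# './averaged.pth' is the default --output of avg_checkpoint.py above; it holds
# a bare state_dict, so it needs no unwrapping before loading.
model = create_model('resnet50')
model.load_state_dict(torch.load('./averaged.pth', map_location='cpu'))
model.eval()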