Remove all prints, change most to logging calls, tweak alignment of batch logs, improve setup.py

pull/16/head
Ross Wightman 5 years ago
parent 1d7f2d93a6
commit 6fc886acaf

@ -30,7 +30,7 @@ I've included a few of my favourite models, but this is not an exhaustive collec
* DPN (from [me](https://github.com/rwightman/pytorch-dpn-pretrained), weights hosted by Cadene)
* DPN-68, DPN-68b, DPN-92, DPN-98, DPN-131, DPN-107
* Generic EfficientNet (from my standalone [GenMobileNet](https://github.com/rwightman/genmobilenet-pytorch)) - A generic model that implements many of the mobile optimized architecture search derived models that utilize similar DepthwiseSeparable and InvertedResidual blocks
* EfficientNet (B0-B4) (https://arxiv.org/abs/1905.11946) -- validated, compat with TF weights
* EfficientNet (B0-B5) (https://arxiv.org/abs/1905.11946) -- validated, compat with TF weights
* MNASNet B1, A1 (Squeeze-Excite), and Small (https://arxiv.org/abs/1807.11626)
* MobileNet-V1 (https://arxiv.org/abs/1704.04861)
* MobileNet-V2 (https://arxiv.org/abs/1801.04381)
@ -187,9 +187,6 @@ To run inference from a checkpoint:
## TODO
A number of additions planned in the future for various projects, incl
* Find optimal training hyperparams and create/port pretraiend weights for the generic MobileNet variants
* Do a model performance (speed + accuracy) benchmarking across all models (make runable as script)
* More training experiments
* Make folder/file layout compat with usage as a module
* Add usage examples to comments, good hyper params for training
* Comments, cleanup and the usual things that get pushed back

@ -8,12 +8,13 @@ from __future__ import print_function
import os
import time
import argparse
import logging
import numpy as np
import torch
from timm.models import create_model, apply_test_time_pool
from timm.data import Dataset, create_loader, resolve_data_config
from timm.utils import AverageMeter
from timm.utils import AverageMeter, setup_default_logging
torch.backends.cudnn.benchmark = True
@ -38,8 +39,8 @@ parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)')
parser.add_argument('--num-classes', type=int, default=1000,
help='Number classes in dataset')
parser.add_argument('--print-freq', '-p', default=10, type=int,
metavar='N', help='print frequency (default: 10)')
parser.add_argument('--log-freq', default=10, type=int,
metavar='N', help='batch logging frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
@ -53,8 +54,8 @@ parser.add_argument('--topk', default=5, type=int,
def main():
setup_default_logging()
args = parser.parse_args()
# might as well try to do something useful...
args.pretrained = args.pretrained or not args.checkpoint
@ -66,8 +67,8 @@ def main():
pretrained=args.pretrained,
checkpoint_path=args.checkpoint)
print('Model %s created, param count: %d' %
(args.model, sum([m.numel() for m in model.parameters()])))
logging.info('Model %s created, param count: %d' %
(args.model, sum([m.numel() for m in model.parameters()])))
config = resolve_data_config(model, args)
model, test_time_pool = apply_test_time_pool(model, config, args)
@ -105,9 +106,8 @@ def main():
batch_time.update(time.time() - end)
end = time.time()
if batch_idx % args.print_freq == 0:
print('Predict: [{0}/{1}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
if batch_idx % args.log_freq == 0:
logging.info('Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
batch_idx, len(loader), batch_time=batch_time))
topk_ids = np.concatenate(topk_ids, axis=0).squeeze()

@ -19,21 +19,27 @@ setup(
url='https://github.com/rwightman/pytorch-image-models',
author='Ross Wightman',
author_email='hello@rwightman.com',
classifiers=[ # Optional
classifiers=[
# How mature is this project? Common values are
# 3 - Alpha
# 4 - Beta
# 5 - Production/Stable
'Development Status :: 3 - Alpha',
'Intended Audience :: Developers',
'Topic :: Software Development :: Build Tools',
'License :: OSI Approved :: Apache License',
'Intended Audience :: Education',
'Intended Audience :: Science/Research',
'License :: OSI Approved :: Apache Software License',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Software Development',
'Topic :: Software Development :: Libraries',
'Topic :: Software Development :: Libraries :: Python Modules',
],
# Note that this is a string of words separated by whitespace, not a list.
keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet',
packages=find_packages(exclude=['convert']),
install_requires=['torch', 'torchvision'],
install_requires=['torch >= 1.0', 'torchvision'],
python_requires='>=3.6',
)

@ -1,3 +1,4 @@
import logging
from .constants import *
@ -56,9 +57,9 @@ def resolve_data_config(model, args, default_cfg={}, verbose=True):
new_config['crop_pct'] = default_cfg['crop_pct']
if verbose:
print('Data processing configuration for current model + dataset:')
logging.info('Data processing configuration for current model + dataset:')
for n, v in new_config.items():
print('\t%s: %s' % (n, str(v)))
logging.info('\t%s: %s' % (n, str(v)))
return new_config

@ -82,7 +82,7 @@ class SelectAdaptivePool2d(nn.Module):
self.pool = nn.AdaptiveMaxPool2d(output_size)
else:
if pool_type != 'avg':
print('Invalid pool type %s specified. Defaulting to average pooling.' % pool_type)
assert False, 'Invalid pool type: %s' % pool_type
self.pool = nn.AdaptiveAvgPool2d(output_size)
def forward(self, x):

@ -86,7 +86,6 @@ def densenet161(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
r"""Densenet-201 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
"""
print(num_classes, in_chans, pretrained)
default_cfg = default_cfgs['densenet161']
model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
num_classes=num_classes, in_chans=in_chans, **kwargs)

@ -17,6 +17,7 @@ Hacked together by Ross Wightman
import math
import re
import logging
from copy import deepcopy
import torch
@ -336,7 +337,7 @@ class _BlockBuilder:
ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn
assert ba['act_fn'] is not None
if self.verbose:
print('args:', ba)
logging.info(' Args: {}'.format(str(ba)))
# could replace this if with lambdas or functools binding if variety increases
if bt == 'ir':
ba['drop_connect_rate'] = self.drop_connect_rate
@ -358,7 +359,7 @@ class _BlockBuilder:
# each stack (stage) contains a list of block arguments
for block_idx, ba in enumerate(stack_args):
if self.verbose:
print('block', block_idx, end=', ')
logging.info(' Block: {}'.format(block_idx))
if block_idx >= 1:
# only the first block in any stack/stage can have a stride > 1
ba['stride'] = 1
@ -370,24 +371,22 @@ class _BlockBuilder:
""" Build the blocks
Args:
in_chs: Number of input-channels passed to first block
arch_def: A list of lists, outer list defines stacks (or stages), inner
block_args: A list of lists, outer list defines stages, inner
list contains strings defining block configuration(s)
Return:
List of block stacks (each stack wrapped in nn.Sequential)
"""
if self.verbose:
print('Building model trunk with %d stacks (stages)...' % len(block_args))
logging.info('Building model trunk with %d stages...' % len(block_args))
self.in_chs = in_chs
blocks = []
# outer list of block_args defines the stacks ('stages' by some conventions)
for stack_idx, stack in enumerate(block_args):
if self.verbose:
print('stack', stack_idx)
logging.info('Stack: {}'.format(stack_idx))
assert isinstance(stack, list)
stack = self._make_stack(stack)
blocks.append(stack)
if self.verbose:
print()
return blocks

@ -1,6 +1,7 @@
import torch
import torch.utils.model_zoo as model_zoo
import os
import logging
from collections import OrderedDict
@ -21,9 +22,9 @@ def load_checkpoint(model, checkpoint_path, use_ema=False):
model.load_state_dict(new_state_dict)
else:
model.load_state_dict(checkpoint)
print("=> Loaded {} from checkpoint '{}'".format(state_dict_key or 'weights', checkpoint_path))
logging.info("Loaded {} from checkpoint '{}'".format(state_dict_key or 'weights', checkpoint_path))
else:
print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
logging.error("No checkpoint found at '{}'".format(checkpoint_path))
raise FileNotFoundError()
@ -40,27 +41,27 @@ def resume_checkpoint(model, checkpoint_path, start_epoch=None):
if 'optimizer' in checkpoint:
optimizer_state = checkpoint['optimizer']
start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch
print("=> Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
logging.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
else:
model.load_state_dict(checkpoint)
start_epoch = 0 if start_epoch is None else start_epoch
print("=> Loaded checkpoint '{}'".format(checkpoint_path))
logging.info("Loaded checkpoint '{}'".format(checkpoint_path))
return optimizer_state, start_epoch
else:
print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
logging.error("No checkpoint found at '{}'".format(checkpoint_path))
raise FileNotFoundError()
def load_pretrained(model, default_cfg, num_classes=1000, in_chans=3, filter_fn=None):
if 'url' not in default_cfg or not default_cfg['url']:
print("Warning: pretrained model URL is invalid, using random initialization.")
logging.warning("Pretrained model URL is invalid, using random initialization.")
return
state_dict = model_zoo.load_url(default_cfg['url'])
if in_chans == 1:
conv1_name = default_cfg['first_conv']
print('Converting first conv (%s) from 3 to 1 channel' % conv1_name)
logging.info('Converting first conv (%s) from 3 to 1 channel' % conv1_name)
conv1_weight = state_dict[conv1_name + '.weight']
state_dict[conv1_name + '.weight'] = conv1_weight.sum(dim=1, keepdim=True)
elif in_chans != 3:

@ -1,3 +1,4 @@
import logging
from torch import nn
import torch.nn.functional as F
from .adaptive_avgmax_pool import adaptive_avgmax_pool2d
@ -31,8 +32,8 @@ def apply_test_time_pool(model, config, args):
if not args.no_test_pool and \
config['input_size'][-1] > model.default_cfg['input_size'][-1] and \
config['input_size'][-2] > model.default_cfg['input_size'][-2]:
print('Target input size %s > pretrained default %s, using test time pooling' %
(str(config['input_size'][-2:]), str(model.default_cfg['input_size'][-2:])))
logging.info('Target input size %s > pretrained default %s, using test time pooling' %
(str(config['input_size'][-2:]), str(model.default_cfg['input_size'][-2:])))
model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size'])
test_time_pool = True
return model, test_time_pool

@ -50,7 +50,6 @@ class TanhLRScheduler(Scheduler):
self.t_in_epochs = t_in_epochs
if self.warmup_t:
t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t)
print(t_v)
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v]
super().update_groups(self.warmup_lr_init)
else:

@ -8,6 +8,7 @@ import shutil
import glob
import csv
import operator
import logging
import numpy as np
from collections import OrderedDict
@ -18,7 +19,7 @@ def get_state_dict(model):
if isinstance(model, ModelEma):
return get_state_dict(model.ema)
else:
return model.module.state_dict() if getattr(model, 'module') else model.state_dict()
return model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
class CheckpointSaver:
@ -29,7 +30,6 @@ class CheckpointSaver:
checkpoint_dir='',
recovery_dir='',
decreasing=False,
verbose=True,
max_history=10):
# state
@ -47,7 +47,6 @@ class CheckpointSaver:
self.extension = '.pth.tar'
self.decreasing = decreasing # a lower metric is better if True
self.cmp = operator.lt if decreasing else operator.gt # True if lhs better than rhs
self.verbose = verbose
self.max_history = max_history
assert self.max_history >= 1
@ -66,11 +65,6 @@ class CheckpointSaver:
self.checkpoint_files, key=lambda x: x[1],
reverse=not self.decreasing) # sort in descending order if a lower metric is not better
if self.verbose:
print("Current checkpoints:")
for c in self.checkpoint_files:
print(c)
if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)):
self.best_epoch = epoch
self.best_metric = metric
@ -100,11 +94,10 @@ class CheckpointSaver:
to_delete = self.checkpoint_files[delete_index:]
for d in to_delete:
try:
if self.verbose:
print('Cleaning checkpoint: ', d)
logging.debug("Cleaning checkpoint: {}".format(d))
os.remove(d[0])
except Exception as e:
print('Exception (%s) while deleting checkpoint' % str(e))
logging.error("Exception '{}' while deleting checkpoint".format(e))
self.checkpoint_files = self.checkpoint_files[:delete_index]
def save_recovery(self, model, optimizer, args, epoch, model_ema=None, batch_idx=0):
@ -114,11 +107,10 @@ class CheckpointSaver:
self._save(save_path, model, optimizer, args, epoch, model_ema)
if os.path.exists(self.last_recovery_file):
try:
if self.verbose:
print('Cleaning recovery', self.last_recovery_file)
logging.debug("Cleaning recovery: {}".format(self.last_recovery_file))
os.remove(self.last_recovery_file)
except Exception as e:
print("Exception (%s) while removing %s" % (str(e), self.last_recovery_file))
logging.error("Exception '{}' while removing {}".format(e, self.last_recovery_file))
self.last_recovery_file = self.curr_recovery_file
self.curr_recovery_file = save_path
@ -253,9 +245,9 @@ class ModelEma:
name = k
new_state_dict[name] = v
self.ema.load_state_dict(new_state_dict)
print("=> Loaded state_dict_ema")
logging.info("Loaded state_dict_ema")
else:
print("=> Failed to find state_dict_ema, starting from loaded model weights")
logging.warning("Failed to find state_dict_ema, starting from loaded model weights")
def update(self, model):
# correct a mismatch in state dict keys
@ -269,3 +261,20 @@ class ModelEma:
if self.device:
model_v = model_v.to(device=self.device)
ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
class FormatterNoInfo(logging.Formatter):
def __init__(self, fmt='%(levelname)s: %(message)s'):
logging.Formatter.__init__(self, fmt)
def format(self, record):
if record.levelno == logging.INFO:
return str(record.getMessage())
return logging.Formatter.format(self, record)
def setup_default_logging(default_level=logging.INFO):
console_handler = logging.StreamHandler()
console_handler.setFormatter(FormatterNoInfo())
logging.root.addHandler(console_handler)
logging.root.setLevel(default_level)

@ -1,6 +1,7 @@
import argparse
import time
import logging
from datetime import datetime
try:
@ -127,14 +128,14 @@ parser.add_argument("--local_rank", default=0, type=int)
def main():
setup_default_logging()
args = parser.parse_args()
args.prefetcher = not args.no_prefetcher
args.distributed = False
if 'WORLD_SIZE' in os.environ:
args.distributed = int(os.environ['WORLD_SIZE']) > 1
if args.distributed and args.num_gpu > 1:
print('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
logging.warning('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
args.num_gpu = 1
args.device = 'cuda:0'
@ -144,17 +145,16 @@ def main():
args.num_gpu = 1
args.device = 'cuda:%d' % args.local_rank
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(
backend='nccl', init_method='env://')
torch.distributed.init_process_group(backend='nccl', init_method='env://')
args.world_size = torch.distributed.get_world_size()
args.rank = torch.distributed.get_rank()
assert args.rank >= 0
if args.distributed:
print('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
% (args.rank, args.world_size))
logging.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
% (args.rank, args.world_size))
else:
print('Training with a single process on %d GPUs.' % args.num_gpu)
logging.info('Training with a single process on %d GPUs.' % args.num_gpu)
torch.manual_seed(args.seed + args.rank)
@ -169,8 +169,8 @@ def main():
bn_eps=args.bn_eps,
checkpoint_path=args.initial_checkpoint)
print('Model %s created, param count: %d' %
(args.model, sum([m.numel() for m in model.parameters()])))
logging.info('Model %s created, param count: %d' %
(args.model, sum([m.numel() for m in model.parameters()])))
data_config = resolve_data_config(model, args, verbose=args.local_rank == 0)
@ -182,8 +182,8 @@ def main():
if args.num_gpu > 1:
if args.amp:
print('Warning: AMP does not work well with nn.DataParallel, disabling. '
'Use distributed mode for multi-GPU AMP.')
logging.warning(
'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.')
args.amp = False
model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
else:
@ -198,10 +198,10 @@ def main():
if has_apex and args.amp:
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
use_amp = True
print('AMP enabled')
logging.info('AMP enabled')
else:
use_amp = False
print('AMP disabled')
logging.info('AMP disabled')
model_ema = None
if args.model_ema:
@ -222,11 +222,11 @@ def main():
if start_epoch > 0:
lr_scheduler.step(start_epoch)
if args.local_rank == 0:
print('Scheduled epochs: ', num_epochs)
logging.info('Scheduled epochs: {}'.format(num_epochs))
train_dir = os.path.join(args.data, 'train')
if not os.path.exists(train_dir):
print('Error: training folder does not exist at: %s' % train_dir)
logging.error('Training folder does not exist at: {}'.format(train_dir))
exit(1)
dataset_train = Dataset(train_dir)
@ -252,7 +252,7 @@ def main():
eval_dir = os.path.join(args.data, 'validation')
if not os.path.isdir(eval_dir):
print('Error: validation folder does not exist at: %s' % eval_dir)
logging.error('Validation folder does not exist at: {}'.format(eval_dir))
exit(1)
dataset_eval = Dataset(eval_dir)
@ -332,7 +332,7 @@ def main():
except KeyboardInterrupt:
pass
if best_metric is not None:
print('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
logging.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
def train_epoch(
@ -394,21 +394,22 @@ def train_epoch(
losses_m.update(reduced_loss.item(), input.size(0))
if args.local_rank == 0:
print('Train: {} [{}/{} ({:.0f}%)] '
'Loss: {loss.val:.6f} ({loss.avg:.4f}) '
'Time: {batch_time.val:.3f}s, {rate:.3f}/s '
'({batch_time.avg:.3f}s, {rate_avg:.3f}/s) '
'LR: {lr:.4f} '
'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
epoch,
batch_idx, len(loader),
100. * batch_idx / last_idx,
loss=losses_m,
batch_time=batch_time_m,
rate=input.size(0) * args.world_size / batch_time_m.val,
rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
lr=lr,
data_time=data_time_m))
logging.info(
'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
'({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
'LR: {lr:.3e} '
'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
epoch,
batch_idx, len(loader),
100. * batch_idx / last_idx,
loss=losses_m,
batch_time=batch_time_m,
rate=input.size(0) * args.world_size / batch_time_m.val,
rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
lr=lr,
data_time=data_time_m))
if args.save_images and output_dir:
torchvision.utils.save_image(
@ -478,14 +479,15 @@ def validate(model, loader, loss_fn, args, log_suffix=''):
end = time.time()
if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
log_name = 'Test' + log_suffix
print('{0}: [{1}/{2}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '
'Loss {loss.val:.4f} ({loss.avg:.4f}) '
'Prec@1 {top1.val:.4f} ({top1.avg:.4f}) '
'Prec@5 {top5.val:.4f} ({top5.avg:.4f})'.format(
log_name, batch_idx, last_idx,
batch_time=batch_time_m, loss=losses_m,
top1=prec1_m, top5=prec5_m))
logging.info(
'{0}: [{1:>4d}/{2}] '
'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
log_name, batch_idx, last_idx,
batch_time=batch_time_m, loss=losses_m,
top1=prec1_m, top5=prec5_m))
metrics = OrderedDict([('loss', losses_m.avg), ('prec1', prec1_m.avg), ('prec5', prec5_m.avg)])

@ -7,6 +7,7 @@ import os
import csv
import glob
import time
import logging
import torch
import torch.nn as nn
import torch.nn.parallel
@ -14,7 +15,7 @@ from collections import OrderedDict
from timm.models import create_model, apply_test_time_pool, load_checkpoint
from timm.data import Dataset, create_loader, resolve_data_config
from timm.utils import accuracy, AverageMeter, natural_key
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging
torch.backends.cudnn.benchmark = True
@ -37,8 +38,8 @@ parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)')
parser.add_argument('--num-classes', type=int, default=1000,
help='Number classes in dataset')
parser.add_argument('--print-freq', '-p', default=10, type=int,
metavar='N', help='print frequency (default: 10)')
parser.add_argument('--log-freq', default=10, type=int,
metavar='N', help='batch logging frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
@ -68,7 +69,7 @@ def validate(args):
load_checkpoint(model, args.checkpoint, args.use_ema)
param_count = sum([m.numel() for m in model.parameters()])
print('Model %s created, param count: %d' % (args.model, param_count))
logging.info('Model %s created, param count: %d' % (args.model, param_count))
data_config = resolve_data_config(model, args)
model, test_time_pool = apply_test_time_pool(model, data_config, args)
@ -118,28 +119,30 @@ def validate(args):
batch_time.update(time.time() - end)
end = time.time()
if i % args.print_freq == 0:
print('Test: [{0}/{1}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s) \t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
i, len(loader), batch_time=batch_time,
rate_avg=input.size(0) / batch_time.avg,
loss=losses, top1=top1, top5=top5))
if i % args.log_freq == 0:
logging.info(
'Test: [{0:>4d}/{1}] '
'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
i, len(loader), batch_time=batch_time,
rate_avg=input.size(0) / batch_time.avg,
loss=losses, top1=top1, top5=top5))
results = OrderedDict(
top1=round(top1.avg, 3), top1_err=round(100 - top1.avg, 3),
top5=round(top5.avg, 3), top5_err=round(100 - top5.avg, 3),
param_count=round(param_count / 1e6, 2))
print(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format(
logging.info(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format(
results['top1'], results['top1_err'], results['top5'], results['top5_err']))
return results
def main():
setup_default_logging()
args = parser.parse_args()
if args.model == 'all':
# validate all models in a list of names with pretrained checkpoints

Loading…
Cancel
Save