diff --git a/README.md b/README.md index 596346af..745fd571 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ I've included a few of my favourite models, but this is not an exhaustive collec * DPN (from [me](https://github.com/rwightman/pytorch-dpn-pretrained), weights hosted by Cadene) * DPN-68, DPN-68b, DPN-92, DPN-98, DPN-131, DPN-107 * Generic EfficientNet (from my standalone [GenMobileNet](https://github.com/rwightman/genmobilenet-pytorch)) - A generic model that implements many of the mobile optimized architecture search derived models that utilize similar DepthwiseSeparable and InvertedResidual blocks - * EfficientNet (B0-B4) (https://arxiv.org/abs/1905.11946) -- validated, compat with TF weights + * EfficientNet (B0-B5) (https://arxiv.org/abs/1905.11946) -- validated, compat with TF weights * MNASNet B1, A1 (Squeeze-Excite), and Small (https://arxiv.org/abs/1807.11626) * MobileNet-V1 (https://arxiv.org/abs/1704.04861) * MobileNet-V2 (https://arxiv.org/abs/1801.04381) @@ -187,9 +187,6 @@ To run inference from a checkpoint: ## TODO A number of additions planned in the future for various projects, incl -* Find optimal training hyperparams and create/port pretraiend weights for the generic MobileNet variants * Do a model performance (speed + accuracy) benchmarking across all models (make runable as script) -* More training experiments -* Make folder/file layout compat with usage as a module * Add usage examples to comments, good hyper params for training * Comments, cleanup and the usual things that get pushed back diff --git a/inference.py b/inference.py index 5aeb258f..9077cc07 100644 --- a/inference.py +++ b/inference.py @@ -8,12 +8,13 @@ from __future__ import print_function import os import time import argparse +import logging import numpy as np import torch from timm.models import create_model, apply_test_time_pool from timm.data import Dataset, create_loader, resolve_data_config -from timm.utils import AverageMeter +from timm.utils import AverageMeter, setup_default_logging torch.backends.cudnn.benchmark = True @@ -38,8 +39,8 @@ parser.add_argument('--interpolation', default='', type=str, metavar='NAME', help='Image resize interpolation type (overrides model)') parser.add_argument('--num-classes', type=int, default=1000, help='Number classes in dataset') -parser.add_argument('--print-freq', '-p', default=10, type=int, - metavar='N', help='print frequency (default: 10)') +parser.add_argument('--log-freq', default=10, type=int, + metavar='N', help='batch logging frequency (default: 10)') parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)') parser.add_argument('--pretrained', dest='pretrained', action='store_true', @@ -53,8 +54,8 @@ parser.add_argument('--topk', default=5, type=int, def main(): + setup_default_logging() args = parser.parse_args() - # might as well try to do something useful... args.pretrained = args.pretrained or not args.checkpoint @@ -66,8 +67,8 @@ def main(): pretrained=args.pretrained, checkpoint_path=args.checkpoint) - print('Model %s created, param count: %d' % - (args.model, sum([m.numel() for m in model.parameters()]))) + logging.info('Model %s created, param count: %d' % + (args.model, sum([m.numel() for m in model.parameters()]))) config = resolve_data_config(model, args) model, test_time_pool = apply_test_time_pool(model, config, args) @@ -105,9 +106,8 @@ def main(): batch_time.update(time.time() - end) end = time.time() - if batch_idx % args.print_freq == 0: - print('Predict: [{0}/{1}]\t' - 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format( + if batch_idx % args.log_freq == 0: + logging.info('Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format( batch_idx, len(loader), batch_time=batch_time)) topk_ids = np.concatenate(topk_ids, axis=0).squeeze() diff --git a/setup.py b/setup.py index 2c00a20a..1d5ecbc5 100644 --- a/setup.py +++ b/setup.py @@ -19,21 +19,27 @@ setup( url='https://github.com/rwightman/pytorch-image-models', author='Ross Wightman', author_email='hello@rwightman.com', - classifiers=[ # Optional + classifiers=[ # How mature is this project? Common values are # 3 - Alpha # 4 - Beta # 5 - Production/Stable 'Development Status :: 3 - Alpha', - 'Intended Audience :: Developers', - 'Topic :: Software Development :: Build Tools', - 'License :: OSI Approved :: Apache License', + 'Intended Audience :: Education', + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: Apache Software License', 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'Topic :: Scientific/Engineering', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + 'Topic :: Software Development', + 'Topic :: Software Development :: Libraries', + 'Topic :: Software Development :: Libraries :: Python Modules', ], # Note that this is a string of words separated by whitespace, not a list. keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet', packages=find_packages(exclude=['convert']), - install_requires=['torch', 'torchvision'], + install_requires=['torch >= 1.0', 'torchvision'], python_requires='>=3.6', ) diff --git a/timm/data/config.py b/timm/data/config.py index 29d6f9e3..1675d2a9 100644 --- a/timm/data/config.py +++ b/timm/data/config.py @@ -1,3 +1,4 @@ +import logging from .constants import * @@ -56,9 +57,9 @@ def resolve_data_config(model, args, default_cfg={}, verbose=True): new_config['crop_pct'] = default_cfg['crop_pct'] if verbose: - print('Data processing configuration for current model + dataset:') + logging.info('Data processing configuration for current model + dataset:') for n, v in new_config.items(): - print('\t%s: %s' % (n, str(v))) + logging.info('\t%s: %s' % (n, str(v))) return new_config diff --git a/timm/models/adaptive_avgmax_pool.py b/timm/models/adaptive_avgmax_pool.py index 9dee407f..c1db890b 100644 --- a/timm/models/adaptive_avgmax_pool.py +++ b/timm/models/adaptive_avgmax_pool.py @@ -82,7 +82,7 @@ class SelectAdaptivePool2d(nn.Module): self.pool = nn.AdaptiveMaxPool2d(output_size) else: if pool_type != 'avg': - print('Invalid pool type %s specified. Defaulting to average pooling.' % pool_type) + assert False, 'Invalid pool type: %s' % pool_type self.pool = nn.AdaptiveAvgPool2d(output_size) def forward(self, x): diff --git a/timm/models/densenet.py b/timm/models/densenet.py index 5f9aeb35..b8e556db 100644 --- a/timm/models/densenet.py +++ b/timm/models/densenet.py @@ -86,7 +86,6 @@ def densenet161(num_classes=1000, in_chans=3, pretrained=False, **kwargs): r"""Densenet-201 model from `"Densely Connected Convolutional Networks" ` """ - print(num_classes, in_chans, pretrained) default_cfg = default_cfgs['densenet161'] model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), num_classes=num_classes, in_chans=in_chans, **kwargs) diff --git a/timm/models/gen_efficientnet.py b/timm/models/gen_efficientnet.py index 3d26d4f6..06f59bcf 100644 --- a/timm/models/gen_efficientnet.py +++ b/timm/models/gen_efficientnet.py @@ -17,6 +17,7 @@ Hacked together by Ross Wightman import math import re +import logging from copy import deepcopy import torch @@ -336,7 +337,7 @@ class _BlockBuilder: ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn assert ba['act_fn'] is not None if self.verbose: - print('args:', ba) + logging.info(' Args: {}'.format(str(ba))) # could replace this if with lambdas or functools binding if variety increases if bt == 'ir': ba['drop_connect_rate'] = self.drop_connect_rate @@ -358,7 +359,7 @@ class _BlockBuilder: # each stack (stage) contains a list of block arguments for block_idx, ba in enumerate(stack_args): if self.verbose: - print('block', block_idx, end=', ') + logging.info(' Block: {}'.format(block_idx)) if block_idx >= 1: # only the first block in any stack/stage can have a stride > 1 ba['stride'] = 1 @@ -370,24 +371,22 @@ class _BlockBuilder: """ Build the blocks Args: in_chs: Number of input-channels passed to first block - arch_def: A list of lists, outer list defines stacks (or stages), inner + block_args: A list of lists, outer list defines stages, inner list contains strings defining block configuration(s) Return: List of block stacks (each stack wrapped in nn.Sequential) """ if self.verbose: - print('Building model trunk with %d stacks (stages)...' % len(block_args)) + logging.info('Building model trunk with %d stages...' % len(block_args)) self.in_chs = in_chs blocks = [] # outer list of block_args defines the stacks ('stages' by some conventions) for stack_idx, stack in enumerate(block_args): if self.verbose: - print('stack', stack_idx) + logging.info('Stack: {}'.format(stack_idx)) assert isinstance(stack, list) stack = self._make_stack(stack) blocks.append(stack) - if self.verbose: - print() return blocks diff --git a/timm/models/helpers.py b/timm/models/helpers.py index 030a37e4..b7de304a 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -1,6 +1,7 @@ import torch import torch.utils.model_zoo as model_zoo import os +import logging from collections import OrderedDict @@ -21,9 +22,9 @@ def load_checkpoint(model, checkpoint_path, use_ema=False): model.load_state_dict(new_state_dict) else: model.load_state_dict(checkpoint) - print("=> Loaded {} from checkpoint '{}'".format(state_dict_key or 'weights', checkpoint_path)) + logging.info("Loaded {} from checkpoint '{}'".format(state_dict_key or 'weights', checkpoint_path)) else: - print("=> Error: No checkpoint found at '{}'".format(checkpoint_path)) + logging.error("No checkpoint found at '{}'".format(checkpoint_path)) raise FileNotFoundError() @@ -40,27 +41,27 @@ def resume_checkpoint(model, checkpoint_path, start_epoch=None): if 'optimizer' in checkpoint: optimizer_state = checkpoint['optimizer'] start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch - print("=> Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) + logging.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) else: model.load_state_dict(checkpoint) start_epoch = 0 if start_epoch is None else start_epoch - print("=> Loaded checkpoint '{}'".format(checkpoint_path)) + logging.info("Loaded checkpoint '{}'".format(checkpoint_path)) return optimizer_state, start_epoch else: - print("=> Error: No checkpoint found at '{}'".format(checkpoint_path)) + logging.error("No checkpoint found at '{}'".format(checkpoint_path)) raise FileNotFoundError() def load_pretrained(model, default_cfg, num_classes=1000, in_chans=3, filter_fn=None): if 'url' not in default_cfg or not default_cfg['url']: - print("Warning: pretrained model URL is invalid, using random initialization.") + logging.warning("Pretrained model URL is invalid, using random initialization.") return state_dict = model_zoo.load_url(default_cfg['url']) if in_chans == 1: conv1_name = default_cfg['first_conv'] - print('Converting first conv (%s) from 3 to 1 channel' % conv1_name) + logging.info('Converting first conv (%s) from 3 to 1 channel' % conv1_name) conv1_weight = state_dict[conv1_name + '.weight'] state_dict[conv1_name + '.weight'] = conv1_weight.sum(dim=1, keepdim=True) elif in_chans != 3: diff --git a/timm/models/test_time_pool.py b/timm/models/test_time_pool.py index ec36380b..7d5bb571 100644 --- a/timm/models/test_time_pool.py +++ b/timm/models/test_time_pool.py @@ -1,3 +1,4 @@ +import logging from torch import nn import torch.nn.functional as F from .adaptive_avgmax_pool import adaptive_avgmax_pool2d @@ -31,8 +32,8 @@ def apply_test_time_pool(model, config, args): if not args.no_test_pool and \ config['input_size'][-1] > model.default_cfg['input_size'][-1] and \ config['input_size'][-2] > model.default_cfg['input_size'][-2]: - print('Target input size %s > pretrained default %s, using test time pooling' % - (str(config['input_size'][-2:]), str(model.default_cfg['input_size'][-2:]))) + logging.info('Target input size %s > pretrained default %s, using test time pooling' % + (str(config['input_size'][-2:]), str(model.default_cfg['input_size'][-2:]))) model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size']) test_time_pool = True return model, test_time_pool diff --git a/timm/scheduler/tanh_lr.py b/timm/scheduler/tanh_lr.py index 795f974d..cb257d0b 100644 --- a/timm/scheduler/tanh_lr.py +++ b/timm/scheduler/tanh_lr.py @@ -50,7 +50,6 @@ class TanhLRScheduler(Scheduler): self.t_in_epochs = t_in_epochs if self.warmup_t: t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t) - print(t_v) self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v] super().update_groups(self.warmup_lr_init) else: diff --git a/timm/utils.py b/timm/utils.py index 626ae9dc..90efdf1b 100644 --- a/timm/utils.py +++ b/timm/utils.py @@ -8,6 +8,7 @@ import shutil import glob import csv import operator +import logging import numpy as np from collections import OrderedDict @@ -18,7 +19,7 @@ def get_state_dict(model): if isinstance(model, ModelEma): return get_state_dict(model.ema) else: - return model.module.state_dict() if getattr(model, 'module') else model.state_dict() + return model.module.state_dict() if hasattr(model, 'module') else model.state_dict() class CheckpointSaver: @@ -29,7 +30,6 @@ class CheckpointSaver: checkpoint_dir='', recovery_dir='', decreasing=False, - verbose=True, max_history=10): # state @@ -47,7 +47,6 @@ class CheckpointSaver: self.extension = '.pth.tar' self.decreasing = decreasing # a lower metric is better if True self.cmp = operator.lt if decreasing else operator.gt # True if lhs better than rhs - self.verbose = verbose self.max_history = max_history assert self.max_history >= 1 @@ -66,11 +65,6 @@ class CheckpointSaver: self.checkpoint_files, key=lambda x: x[1], reverse=not self.decreasing) # sort in descending order if a lower metric is not better - if self.verbose: - print("Current checkpoints:") - for c in self.checkpoint_files: - print(c) - if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)): self.best_epoch = epoch self.best_metric = metric @@ -100,11 +94,10 @@ class CheckpointSaver: to_delete = self.checkpoint_files[delete_index:] for d in to_delete: try: - if self.verbose: - print('Cleaning checkpoint: ', d) + logging.debug("Cleaning checkpoint: {}".format(d)) os.remove(d[0]) except Exception as e: - print('Exception (%s) while deleting checkpoint' % str(e)) + logging.error("Exception '{}' while deleting checkpoint".format(e)) self.checkpoint_files = self.checkpoint_files[:delete_index] def save_recovery(self, model, optimizer, args, epoch, model_ema=None, batch_idx=0): @@ -114,11 +107,10 @@ class CheckpointSaver: self._save(save_path, model, optimizer, args, epoch, model_ema) if os.path.exists(self.last_recovery_file): try: - if self.verbose: - print('Cleaning recovery', self.last_recovery_file) + logging.debug("Cleaning recovery: {}".format(self.last_recovery_file)) os.remove(self.last_recovery_file) except Exception as e: - print("Exception (%s) while removing %s" % (str(e), self.last_recovery_file)) + logging.error("Exception '{}' while removing {}".format(e, self.last_recovery_file)) self.last_recovery_file = self.curr_recovery_file self.curr_recovery_file = save_path @@ -253,9 +245,9 @@ class ModelEma: name = k new_state_dict[name] = v self.ema.load_state_dict(new_state_dict) - print("=> Loaded state_dict_ema") + logging.info("Loaded state_dict_ema") else: - print("=> Failed to find state_dict_ema, starting from loaded model weights") + logging.warning("Failed to find state_dict_ema, starting from loaded model weights") def update(self, model): # correct a mismatch in state dict keys @@ -269,3 +261,20 @@ class ModelEma: if self.device: model_v = model_v.to(device=self.device) ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v) + + +class FormatterNoInfo(logging.Formatter): + def __init__(self, fmt='%(levelname)s: %(message)s'): + logging.Formatter.__init__(self, fmt) + + def format(self, record): + if record.levelno == logging.INFO: + return str(record.getMessage()) + return logging.Formatter.format(self, record) + + +def setup_default_logging(default_level=logging.INFO): + console_handler = logging.StreamHandler() + console_handler.setFormatter(FormatterNoInfo()) + logging.root.addHandler(console_handler) + logging.root.setLevel(default_level) diff --git a/train.py b/train.py index b2b033c8..1bf482e0 100644 --- a/train.py +++ b/train.py @@ -1,6 +1,7 @@ import argparse import time +import logging from datetime import datetime try: @@ -127,14 +128,14 @@ parser.add_argument("--local_rank", default=0, type=int) def main(): + setup_default_logging() args = parser.parse_args() - args.prefetcher = not args.no_prefetcher args.distributed = False if 'WORLD_SIZE' in os.environ: args.distributed = int(os.environ['WORLD_SIZE']) > 1 if args.distributed and args.num_gpu > 1: - print('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.') + logging.warning('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.') args.num_gpu = 1 args.device = 'cuda:0' @@ -144,17 +145,16 @@ def main(): args.num_gpu = 1 args.device = 'cuda:%d' % args.local_rank torch.cuda.set_device(args.local_rank) - torch.distributed.init_process_group( - backend='nccl', init_method='env://') + torch.distributed.init_process_group(backend='nccl', init_method='env://') args.world_size = torch.distributed.get_world_size() args.rank = torch.distributed.get_rank() assert args.rank >= 0 if args.distributed: - print('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' - % (args.rank, args.world_size)) + logging.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' + % (args.rank, args.world_size)) else: - print('Training with a single process on %d GPUs.' % args.num_gpu) + logging.info('Training with a single process on %d GPUs.' % args.num_gpu) torch.manual_seed(args.seed + args.rank) @@ -169,8 +169,8 @@ def main(): bn_eps=args.bn_eps, checkpoint_path=args.initial_checkpoint) - print('Model %s created, param count: %d' % - (args.model, sum([m.numel() for m in model.parameters()]))) + logging.info('Model %s created, param count: %d' % + (args.model, sum([m.numel() for m in model.parameters()]))) data_config = resolve_data_config(model, args, verbose=args.local_rank == 0) @@ -182,8 +182,8 @@ def main(): if args.num_gpu > 1: if args.amp: - print('Warning: AMP does not work well with nn.DataParallel, disabling. ' - 'Use distributed mode for multi-GPU AMP.') + logging.warning( + 'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.') args.amp = False model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() else: @@ -198,10 +198,10 @@ def main(): if has_apex and args.amp: model, optimizer = amp.initialize(model, optimizer, opt_level='O1') use_amp = True - print('AMP enabled') + logging.info('AMP enabled') else: use_amp = False - print('AMP disabled') + logging.info('AMP disabled') model_ema = None if args.model_ema: @@ -222,11 +222,11 @@ def main(): if start_epoch > 0: lr_scheduler.step(start_epoch) if args.local_rank == 0: - print('Scheduled epochs: ', num_epochs) + logging.info('Scheduled epochs: {}'.format(num_epochs)) train_dir = os.path.join(args.data, 'train') if not os.path.exists(train_dir): - print('Error: training folder does not exist at: %s' % train_dir) + logging.error('Training folder does not exist at: {}'.format(train_dir)) exit(1) dataset_train = Dataset(train_dir) @@ -252,7 +252,7 @@ def main(): eval_dir = os.path.join(args.data, 'validation') if not os.path.isdir(eval_dir): - print('Error: validation folder does not exist at: %s' % eval_dir) + logging.error('Validation folder does not exist at: {}'.format(eval_dir)) exit(1) dataset_eval = Dataset(eval_dir) @@ -332,7 +332,7 @@ def main(): except KeyboardInterrupt: pass if best_metric is not None: - print('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch)) + logging.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch)) def train_epoch( @@ -394,21 +394,22 @@ def train_epoch( losses_m.update(reduced_loss.item(), input.size(0)) if args.local_rank == 0: - print('Train: {} [{}/{} ({:.0f}%)] ' - 'Loss: {loss.val:.6f} ({loss.avg:.4f}) ' - 'Time: {batch_time.val:.3f}s, {rate:.3f}/s ' - '({batch_time.avg:.3f}s, {rate_avg:.3f}/s) ' - 'LR: {lr:.4f} ' - 'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format( - epoch, - batch_idx, len(loader), - 100. * batch_idx / last_idx, - loss=losses_m, - batch_time=batch_time_m, - rate=input.size(0) * args.world_size / batch_time_m.val, - rate_avg=input.size(0) * args.world_size / batch_time_m.avg, - lr=lr, - data_time=data_time_m)) + logging.info( + 'Train: {} [{:>4d}/{} ({:>3.0f}%)] ' + 'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) ' + 'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s ' + '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) ' + 'LR: {lr:.3e} ' + 'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format( + epoch, + batch_idx, len(loader), + 100. * batch_idx / last_idx, + loss=losses_m, + batch_time=batch_time_m, + rate=input.size(0) * args.world_size / batch_time_m.val, + rate_avg=input.size(0) * args.world_size / batch_time_m.avg, + lr=lr, + data_time=data_time_m)) if args.save_images and output_dir: torchvision.utils.save_image( @@ -478,14 +479,15 @@ def validate(model, loader, loss_fn, args, log_suffix=''): end = time.time() if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0): log_name = 'Test' + log_suffix - print('{0}: [{1}/{2}]\t' - 'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) ' - 'Loss {loss.val:.4f} ({loss.avg:.4f}) ' - 'Prec@1 {top1.val:.4f} ({top1.avg:.4f}) ' - 'Prec@5 {top5.val:.4f} ({top5.avg:.4f})'.format( - log_name, batch_idx, last_idx, - batch_time=batch_time_m, loss=losses_m, - top1=prec1_m, top5=prec5_m)) + logging.info( + '{0}: [{1:>4d}/{2}] ' + 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' + 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' + 'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) ' + 'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format( + log_name, batch_idx, last_idx, + batch_time=batch_time_m, loss=losses_m, + top1=prec1_m, top5=prec5_m)) metrics = OrderedDict([('loss', losses_m.avg), ('prec1', prec1_m.avg), ('prec5', prec5_m.avg)]) diff --git a/validate.py b/validate.py index 453ea514..af7e343b 100644 --- a/validate.py +++ b/validate.py @@ -7,6 +7,7 @@ import os import csv import glob import time +import logging import torch import torch.nn as nn import torch.nn.parallel @@ -14,7 +15,7 @@ from collections import OrderedDict from timm.models import create_model, apply_test_time_pool, load_checkpoint from timm.data import Dataset, create_loader, resolve_data_config -from timm.utils import accuracy, AverageMeter, natural_key +from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging torch.backends.cudnn.benchmark = True @@ -37,8 +38,8 @@ parser.add_argument('--interpolation', default='', type=str, metavar='NAME', help='Image resize interpolation type (overrides model)') parser.add_argument('--num-classes', type=int, default=1000, help='Number classes in dataset') -parser.add_argument('--print-freq', '-p', default=10, type=int, - metavar='N', help='print frequency (default: 10)') +parser.add_argument('--log-freq', default=10, type=int, + metavar='N', help='batch logging frequency (default: 10)') parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)') parser.add_argument('--pretrained', dest='pretrained', action='store_true', @@ -68,7 +69,7 @@ def validate(args): load_checkpoint(model, args.checkpoint, args.use_ema) param_count = sum([m.numel() for m in model.parameters()]) - print('Model %s created, param count: %d' % (args.model, param_count)) + logging.info('Model %s created, param count: %d' % (args.model, param_count)) data_config = resolve_data_config(model, args) model, test_time_pool = apply_test_time_pool(model, data_config, args) @@ -118,28 +119,30 @@ def validate(args): batch_time.update(time.time() - end) end = time.time() - if i % args.print_freq == 0: - print('Test: [{0}/{1}]\t' - 'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s) \t' - 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' - 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' - 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( - i, len(loader), batch_time=batch_time, - rate_avg=input.size(0) / batch_time.avg, - loss=losses, top1=top1, top5=top5)) + if i % args.log_freq == 0: + logging.info( + 'Test: [{0:>4d}/{1}] ' + 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' + 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' + 'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) ' + 'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format( + i, len(loader), batch_time=batch_time, + rate_avg=input.size(0) / batch_time.avg, + loss=losses, top1=top1, top5=top5)) results = OrderedDict( top1=round(top1.avg, 3), top1_err=round(100 - top1.avg, 3), top5=round(top5.avg, 3), top5_err=round(100 - top5.avg, 3), param_count=round(param_count / 1e6, 2)) - print(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format( + logging.info(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format( results['top1'], results['top1_err'], results['top5'], results['top5_err'])) return results def main(): + setup_default_logging() args = parser.parse_args() if args.model == 'all': # validate all models in a list of names with pretrained checkpoints