Remove all prints, change most to logging calls, tweak alignment of batch logs, improve setup.py

pull/16/head
Ross Wightman 5 years ago
parent 1d7f2d93a6
commit 6fc886acaf

@ -30,7 +30,7 @@ I've included a few of my favourite models, but this is not an exhaustive collec
* DPN (from [me](https://github.com/rwightman/pytorch-dpn-pretrained), weights hosted by Cadene) * DPN (from [me](https://github.com/rwightman/pytorch-dpn-pretrained), weights hosted by Cadene)
* DPN-68, DPN-68b, DPN-92, DPN-98, DPN-131, DPN-107 * DPN-68, DPN-68b, DPN-92, DPN-98, DPN-131, DPN-107
* Generic EfficientNet (from my standalone [GenMobileNet](https://github.com/rwightman/genmobilenet-pytorch)) - A generic model that implements many of the mobile optimized architecture search derived models that utilize similar DepthwiseSeparable and InvertedResidual blocks * Generic EfficientNet (from my standalone [GenMobileNet](https://github.com/rwightman/genmobilenet-pytorch)) - A generic model that implements many of the mobile optimized architecture search derived models that utilize similar DepthwiseSeparable and InvertedResidual blocks
* EfficientNet (B0-B4) (https://arxiv.org/abs/1905.11946) -- validated, compat with TF weights * EfficientNet (B0-B5) (https://arxiv.org/abs/1905.11946) -- validated, compat with TF weights
* MNASNet B1, A1 (Squeeze-Excite), and Small (https://arxiv.org/abs/1807.11626) * MNASNet B1, A1 (Squeeze-Excite), and Small (https://arxiv.org/abs/1807.11626)
* MobileNet-V1 (https://arxiv.org/abs/1704.04861) * MobileNet-V1 (https://arxiv.org/abs/1704.04861)
* MobileNet-V2 (https://arxiv.org/abs/1801.04381) * MobileNet-V2 (https://arxiv.org/abs/1801.04381)
@ -187,9 +187,6 @@ To run inference from a checkpoint:
## TODO ## TODO
A number of additions planned in the future for various projects, incl A number of additions planned in the future for various projects, incl
* Find optimal training hyperparams and create/port pretraiend weights for the generic MobileNet variants
* Do a model performance (speed + accuracy) benchmarking across all models (make runable as script) * Do a model performance (speed + accuracy) benchmarking across all models (make runable as script)
* More training experiments
* Make folder/file layout compat with usage as a module
* Add usage examples to comments, good hyper params for training * Add usage examples to comments, good hyper params for training
* Comments, cleanup and the usual things that get pushed back * Comments, cleanup and the usual things that get pushed back

@ -8,12 +8,13 @@ from __future__ import print_function
import os import os
import time import time
import argparse import argparse
import logging
import numpy as np import numpy as np
import torch import torch
from timm.models import create_model, apply_test_time_pool from timm.models import create_model, apply_test_time_pool
from timm.data import Dataset, create_loader, resolve_data_config from timm.data import Dataset, create_loader, resolve_data_config
from timm.utils import AverageMeter from timm.utils import AverageMeter, setup_default_logging
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
@ -38,8 +39,8 @@ parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)') help='Image resize interpolation type (overrides model)')
parser.add_argument('--num-classes', type=int, default=1000, parser.add_argument('--num-classes', type=int, default=1000,
help='Number classes in dataset') help='Number classes in dataset')
parser.add_argument('--print-freq', '-p', default=10, type=int, parser.add_argument('--log-freq', default=10, type=int,
metavar='N', help='print frequency (default: 10)') metavar='N', help='batch logging frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)') help='path to latest checkpoint (default: none)')
parser.add_argument('--pretrained', dest='pretrained', action='store_true', parser.add_argument('--pretrained', dest='pretrained', action='store_true',
@ -53,8 +54,8 @@ parser.add_argument('--topk', default=5, type=int,
def main(): def main():
setup_default_logging()
args = parser.parse_args() args = parser.parse_args()
# might as well try to do something useful... # might as well try to do something useful...
args.pretrained = args.pretrained or not args.checkpoint args.pretrained = args.pretrained or not args.checkpoint
@ -66,8 +67,8 @@ def main():
pretrained=args.pretrained, pretrained=args.pretrained,
checkpoint_path=args.checkpoint) checkpoint_path=args.checkpoint)
print('Model %s created, param count: %d' % logging.info('Model %s created, param count: %d' %
(args.model, sum([m.numel() for m in model.parameters()]))) (args.model, sum([m.numel() for m in model.parameters()])))
config = resolve_data_config(model, args) config = resolve_data_config(model, args)
model, test_time_pool = apply_test_time_pool(model, config, args) model, test_time_pool = apply_test_time_pool(model, config, args)
@ -105,9 +106,8 @@ def main():
batch_time.update(time.time() - end) batch_time.update(time.time() - end)
end = time.time() end = time.time()
if batch_idx % args.print_freq == 0: if batch_idx % args.log_freq == 0:
print('Predict: [{0}/{1}]\t' logging.info('Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
batch_idx, len(loader), batch_time=batch_time)) batch_idx, len(loader), batch_time=batch_time))
topk_ids = np.concatenate(topk_ids, axis=0).squeeze() topk_ids = np.concatenate(topk_ids, axis=0).squeeze()

@ -19,21 +19,27 @@ setup(
url='https://github.com/rwightman/pytorch-image-models', url='https://github.com/rwightman/pytorch-image-models',
author='Ross Wightman', author='Ross Wightman',
author_email='hello@rwightman.com', author_email='hello@rwightman.com',
classifiers=[ # Optional classifiers=[
# How mature is this project? Common values are # How mature is this project? Common values are
# 3 - Alpha # 3 - Alpha
# 4 - Beta # 4 - Beta
# 5 - Production/Stable # 5 - Production/Stable
'Development Status :: 3 - Alpha', 'Development Status :: 3 - Alpha',
'Intended Audience :: Developers', 'Intended Audience :: Education',
'Topic :: Software Development :: Build Tools', 'Intended Audience :: Science/Research',
'License :: OSI Approved :: Apache License', 'License :: OSI Approved :: Apache Software License',
'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Software Development',
'Topic :: Software Development :: Libraries',
'Topic :: Software Development :: Libraries :: Python Modules',
], ],
# Note that this is a string of words separated by whitespace, not a list. # Note that this is a string of words separated by whitespace, not a list.
keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet', keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet',
packages=find_packages(exclude=['convert']), packages=find_packages(exclude=['convert']),
install_requires=['torch', 'torchvision'], install_requires=['torch >= 1.0', 'torchvision'],
python_requires='>=3.6', python_requires='>=3.6',
) )

@ -1,3 +1,4 @@
import logging
from .constants import * from .constants import *
@ -56,9 +57,9 @@ def resolve_data_config(model, args, default_cfg={}, verbose=True):
new_config['crop_pct'] = default_cfg['crop_pct'] new_config['crop_pct'] = default_cfg['crop_pct']
if verbose: if verbose:
print('Data processing configuration for current model + dataset:') logging.info('Data processing configuration for current model + dataset:')
for n, v in new_config.items(): for n, v in new_config.items():
print('\t%s: %s' % (n, str(v))) logging.info('\t%s: %s' % (n, str(v)))
return new_config return new_config

@ -82,7 +82,7 @@ class SelectAdaptivePool2d(nn.Module):
self.pool = nn.AdaptiveMaxPool2d(output_size) self.pool = nn.AdaptiveMaxPool2d(output_size)
else: else:
if pool_type != 'avg': if pool_type != 'avg':
print('Invalid pool type %s specified. Defaulting to average pooling.' % pool_type) assert False, 'Invalid pool type: %s' % pool_type
self.pool = nn.AdaptiveAvgPool2d(output_size) self.pool = nn.AdaptiveAvgPool2d(output_size)
def forward(self, x): def forward(self, x):

@ -86,7 +86,6 @@ def densenet161(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
r"""Densenet-201 model from r"""Densenet-201 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>` `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
""" """
print(num_classes, in_chans, pretrained)
default_cfg = default_cfgs['densenet161'] default_cfg = default_cfgs['densenet161']
model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
num_classes=num_classes, in_chans=in_chans, **kwargs) num_classes=num_classes, in_chans=in_chans, **kwargs)

@ -17,6 +17,7 @@ Hacked together by Ross Wightman
import math import math
import re import re
import logging
from copy import deepcopy from copy import deepcopy
import torch import torch
@ -336,7 +337,7 @@ class _BlockBuilder:
ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn
assert ba['act_fn'] is not None assert ba['act_fn'] is not None
if self.verbose: if self.verbose:
print('args:', ba) logging.info(' Args: {}'.format(str(ba)))
# could replace this if with lambdas or functools binding if variety increases # could replace this if with lambdas or functools binding if variety increases
if bt == 'ir': if bt == 'ir':
ba['drop_connect_rate'] = self.drop_connect_rate ba['drop_connect_rate'] = self.drop_connect_rate
@ -358,7 +359,7 @@ class _BlockBuilder:
# each stack (stage) contains a list of block arguments # each stack (stage) contains a list of block arguments
for block_idx, ba in enumerate(stack_args): for block_idx, ba in enumerate(stack_args):
if self.verbose: if self.verbose:
print('block', block_idx, end=', ') logging.info(' Block: {}'.format(block_idx))
if block_idx >= 1: if block_idx >= 1:
# only the first block in any stack/stage can have a stride > 1 # only the first block in any stack/stage can have a stride > 1
ba['stride'] = 1 ba['stride'] = 1
@ -370,24 +371,22 @@ class _BlockBuilder:
""" Build the blocks """ Build the blocks
Args: Args:
in_chs: Number of input-channels passed to first block in_chs: Number of input-channels passed to first block
arch_def: A list of lists, outer list defines stacks (or stages), inner block_args: A list of lists, outer list defines stages, inner
list contains strings defining block configuration(s) list contains strings defining block configuration(s)
Return: Return:
List of block stacks (each stack wrapped in nn.Sequential) List of block stacks (each stack wrapped in nn.Sequential)
""" """
if self.verbose: if self.verbose:
print('Building model trunk with %d stacks (stages)...' % len(block_args)) logging.info('Building model trunk with %d stages...' % len(block_args))
self.in_chs = in_chs self.in_chs = in_chs
blocks = [] blocks = []
# outer list of block_args defines the stacks ('stages' by some conventions) # outer list of block_args defines the stacks ('stages' by some conventions)
for stack_idx, stack in enumerate(block_args): for stack_idx, stack in enumerate(block_args):
if self.verbose: if self.verbose:
print('stack', stack_idx) logging.info('Stack: {}'.format(stack_idx))
assert isinstance(stack, list) assert isinstance(stack, list)
stack = self._make_stack(stack) stack = self._make_stack(stack)
blocks.append(stack) blocks.append(stack)
if self.verbose:
print()
return blocks return blocks

@ -1,6 +1,7 @@
import torch import torch
import torch.utils.model_zoo as model_zoo import torch.utils.model_zoo as model_zoo
import os import os
import logging
from collections import OrderedDict from collections import OrderedDict
@ -21,9 +22,9 @@ def load_checkpoint(model, checkpoint_path, use_ema=False):
model.load_state_dict(new_state_dict) model.load_state_dict(new_state_dict)
else: else:
model.load_state_dict(checkpoint) model.load_state_dict(checkpoint)
print("=> Loaded {} from checkpoint '{}'".format(state_dict_key or 'weights', checkpoint_path)) logging.info("Loaded {} from checkpoint '{}'".format(state_dict_key or 'weights', checkpoint_path))
else: else:
print("=> Error: No checkpoint found at '{}'".format(checkpoint_path)) logging.error("No checkpoint found at '{}'".format(checkpoint_path))
raise FileNotFoundError() raise FileNotFoundError()
@ -40,27 +41,27 @@ def resume_checkpoint(model, checkpoint_path, start_epoch=None):
if 'optimizer' in checkpoint: if 'optimizer' in checkpoint:
optimizer_state = checkpoint['optimizer'] optimizer_state = checkpoint['optimizer']
start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch
print("=> Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) logging.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
else: else:
model.load_state_dict(checkpoint) model.load_state_dict(checkpoint)
start_epoch = 0 if start_epoch is None else start_epoch start_epoch = 0 if start_epoch is None else start_epoch
print("=> Loaded checkpoint '{}'".format(checkpoint_path)) logging.info("Loaded checkpoint '{}'".format(checkpoint_path))
return optimizer_state, start_epoch return optimizer_state, start_epoch
else: else:
print("=> Error: No checkpoint found at '{}'".format(checkpoint_path)) logging.error("No checkpoint found at '{}'".format(checkpoint_path))
raise FileNotFoundError() raise FileNotFoundError()
def load_pretrained(model, default_cfg, num_classes=1000, in_chans=3, filter_fn=None): def load_pretrained(model, default_cfg, num_classes=1000, in_chans=3, filter_fn=None):
if 'url' not in default_cfg or not default_cfg['url']: if 'url' not in default_cfg or not default_cfg['url']:
print("Warning: pretrained model URL is invalid, using random initialization.") logging.warning("Pretrained model URL is invalid, using random initialization.")
return return
state_dict = model_zoo.load_url(default_cfg['url']) state_dict = model_zoo.load_url(default_cfg['url'])
if in_chans == 1: if in_chans == 1:
conv1_name = default_cfg['first_conv'] conv1_name = default_cfg['first_conv']
print('Converting first conv (%s) from 3 to 1 channel' % conv1_name) logging.info('Converting first conv (%s) from 3 to 1 channel' % conv1_name)
conv1_weight = state_dict[conv1_name + '.weight'] conv1_weight = state_dict[conv1_name + '.weight']
state_dict[conv1_name + '.weight'] = conv1_weight.sum(dim=1, keepdim=True) state_dict[conv1_name + '.weight'] = conv1_weight.sum(dim=1, keepdim=True)
elif in_chans != 3: elif in_chans != 3:

@ -1,3 +1,4 @@
import logging
from torch import nn from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
from .adaptive_avgmax_pool import adaptive_avgmax_pool2d from .adaptive_avgmax_pool import adaptive_avgmax_pool2d
@ -31,8 +32,8 @@ def apply_test_time_pool(model, config, args):
if not args.no_test_pool and \ if not args.no_test_pool and \
config['input_size'][-1] > model.default_cfg['input_size'][-1] and \ config['input_size'][-1] > model.default_cfg['input_size'][-1] and \
config['input_size'][-2] > model.default_cfg['input_size'][-2]: config['input_size'][-2] > model.default_cfg['input_size'][-2]:
print('Target input size %s > pretrained default %s, using test time pooling' % logging.info('Target input size %s > pretrained default %s, using test time pooling' %
(str(config['input_size'][-2:]), str(model.default_cfg['input_size'][-2:]))) (str(config['input_size'][-2:]), str(model.default_cfg['input_size'][-2:])))
model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size']) model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size'])
test_time_pool = True test_time_pool = True
return model, test_time_pool return model, test_time_pool

@ -50,7 +50,6 @@ class TanhLRScheduler(Scheduler):
self.t_in_epochs = t_in_epochs self.t_in_epochs = t_in_epochs
if self.warmup_t: if self.warmup_t:
t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t) t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t)
print(t_v)
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v] self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v]
super().update_groups(self.warmup_lr_init) super().update_groups(self.warmup_lr_init)
else: else:

@ -8,6 +8,7 @@ import shutil
import glob import glob
import csv import csv
import operator import operator
import logging
import numpy as np import numpy as np
from collections import OrderedDict from collections import OrderedDict
@ -18,7 +19,7 @@ def get_state_dict(model):
if isinstance(model, ModelEma): if isinstance(model, ModelEma):
return get_state_dict(model.ema) return get_state_dict(model.ema)
else: else:
return model.module.state_dict() if getattr(model, 'module') else model.state_dict() return model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
class CheckpointSaver: class CheckpointSaver:
@ -29,7 +30,6 @@ class CheckpointSaver:
checkpoint_dir='', checkpoint_dir='',
recovery_dir='', recovery_dir='',
decreasing=False, decreasing=False,
verbose=True,
max_history=10): max_history=10):
# state # state
@ -47,7 +47,6 @@ class CheckpointSaver:
self.extension = '.pth.tar' self.extension = '.pth.tar'
self.decreasing = decreasing # a lower metric is better if True self.decreasing = decreasing # a lower metric is better if True
self.cmp = operator.lt if decreasing else operator.gt # True if lhs better than rhs self.cmp = operator.lt if decreasing else operator.gt # True if lhs better than rhs
self.verbose = verbose
self.max_history = max_history self.max_history = max_history
assert self.max_history >= 1 assert self.max_history >= 1
@ -66,11 +65,6 @@ class CheckpointSaver:
self.checkpoint_files, key=lambda x: x[1], self.checkpoint_files, key=lambda x: x[1],
reverse=not self.decreasing) # sort in descending order if a lower metric is not better reverse=not self.decreasing) # sort in descending order if a lower metric is not better
if self.verbose:
print("Current checkpoints:")
for c in self.checkpoint_files:
print(c)
if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)): if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)):
self.best_epoch = epoch self.best_epoch = epoch
self.best_metric = metric self.best_metric = metric
@ -100,11 +94,10 @@ class CheckpointSaver:
to_delete = self.checkpoint_files[delete_index:] to_delete = self.checkpoint_files[delete_index:]
for d in to_delete: for d in to_delete:
try: try:
if self.verbose: logging.debug("Cleaning checkpoint: {}".format(d))
print('Cleaning checkpoint: ', d)
os.remove(d[0]) os.remove(d[0])
except Exception as e: except Exception as e:
print('Exception (%s) while deleting checkpoint' % str(e)) logging.error("Exception '{}' while deleting checkpoint".format(e))
self.checkpoint_files = self.checkpoint_files[:delete_index] self.checkpoint_files = self.checkpoint_files[:delete_index]
def save_recovery(self, model, optimizer, args, epoch, model_ema=None, batch_idx=0): def save_recovery(self, model, optimizer, args, epoch, model_ema=None, batch_idx=0):
@ -114,11 +107,10 @@ class CheckpointSaver:
self._save(save_path, model, optimizer, args, epoch, model_ema) self._save(save_path, model, optimizer, args, epoch, model_ema)
if os.path.exists(self.last_recovery_file): if os.path.exists(self.last_recovery_file):
try: try:
if self.verbose: logging.debug("Cleaning recovery: {}".format(self.last_recovery_file))
print('Cleaning recovery', self.last_recovery_file)
os.remove(self.last_recovery_file) os.remove(self.last_recovery_file)
except Exception as e: except Exception as e:
print("Exception (%s) while removing %s" % (str(e), self.last_recovery_file)) logging.error("Exception '{}' while removing {}".format(e, self.last_recovery_file))
self.last_recovery_file = self.curr_recovery_file self.last_recovery_file = self.curr_recovery_file
self.curr_recovery_file = save_path self.curr_recovery_file = save_path
@ -253,9 +245,9 @@ class ModelEma:
name = k name = k
new_state_dict[name] = v new_state_dict[name] = v
self.ema.load_state_dict(new_state_dict) self.ema.load_state_dict(new_state_dict)
print("=> Loaded state_dict_ema") logging.info("Loaded state_dict_ema")
else: else:
print("=> Failed to find state_dict_ema, starting from loaded model weights") logging.warning("Failed to find state_dict_ema, starting from loaded model weights")
def update(self, model): def update(self, model):
# correct a mismatch in state dict keys # correct a mismatch in state dict keys
@ -269,3 +261,20 @@ class ModelEma:
if self.device: if self.device:
model_v = model_v.to(device=self.device) model_v = model_v.to(device=self.device)
ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v) ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
class FormatterNoInfo(logging.Formatter):
def __init__(self, fmt='%(levelname)s: %(message)s'):
logging.Formatter.__init__(self, fmt)
def format(self, record):
if record.levelno == logging.INFO:
return str(record.getMessage())
return logging.Formatter.format(self, record)
def setup_default_logging(default_level=logging.INFO):
console_handler = logging.StreamHandler()
console_handler.setFormatter(FormatterNoInfo())
logging.root.addHandler(console_handler)
logging.root.setLevel(default_level)

@ -1,6 +1,7 @@
import argparse import argparse
import time import time
import logging
from datetime import datetime from datetime import datetime
try: try:
@ -127,14 +128,14 @@ parser.add_argument("--local_rank", default=0, type=int)
def main(): def main():
setup_default_logging()
args = parser.parse_args() args = parser.parse_args()
args.prefetcher = not args.no_prefetcher args.prefetcher = not args.no_prefetcher
args.distributed = False args.distributed = False
if 'WORLD_SIZE' in os.environ: if 'WORLD_SIZE' in os.environ:
args.distributed = int(os.environ['WORLD_SIZE']) > 1 args.distributed = int(os.environ['WORLD_SIZE']) > 1
if args.distributed and args.num_gpu > 1: if args.distributed and args.num_gpu > 1:
print('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.') logging.warning('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
args.num_gpu = 1 args.num_gpu = 1
args.device = 'cuda:0' args.device = 'cuda:0'
@ -144,17 +145,16 @@ def main():
args.num_gpu = 1 args.num_gpu = 1
args.device = 'cuda:%d' % args.local_rank args.device = 'cuda:%d' % args.local_rank
torch.cuda.set_device(args.local_rank) torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group( torch.distributed.init_process_group(backend='nccl', init_method='env://')
backend='nccl', init_method='env://')
args.world_size = torch.distributed.get_world_size() args.world_size = torch.distributed.get_world_size()
args.rank = torch.distributed.get_rank() args.rank = torch.distributed.get_rank()
assert args.rank >= 0 assert args.rank >= 0
if args.distributed: if args.distributed:
print('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' logging.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
% (args.rank, args.world_size)) % (args.rank, args.world_size))
else: else:
print('Training with a single process on %d GPUs.' % args.num_gpu) logging.info('Training with a single process on %d GPUs.' % args.num_gpu)
torch.manual_seed(args.seed + args.rank) torch.manual_seed(args.seed + args.rank)
@ -169,8 +169,8 @@ def main():
bn_eps=args.bn_eps, bn_eps=args.bn_eps,
checkpoint_path=args.initial_checkpoint) checkpoint_path=args.initial_checkpoint)
print('Model %s created, param count: %d' % logging.info('Model %s created, param count: %d' %
(args.model, sum([m.numel() for m in model.parameters()]))) (args.model, sum([m.numel() for m in model.parameters()])))
data_config = resolve_data_config(model, args, verbose=args.local_rank == 0) data_config = resolve_data_config(model, args, verbose=args.local_rank == 0)
@ -182,8 +182,8 @@ def main():
if args.num_gpu > 1: if args.num_gpu > 1:
if args.amp: if args.amp:
print('Warning: AMP does not work well with nn.DataParallel, disabling. ' logging.warning(
'Use distributed mode for multi-GPU AMP.') 'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.')
args.amp = False args.amp = False
model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
else: else:
@ -198,10 +198,10 @@ def main():
if has_apex and args.amp: if has_apex and args.amp:
model, optimizer = amp.initialize(model, optimizer, opt_level='O1') model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
use_amp = True use_amp = True
print('AMP enabled') logging.info('AMP enabled')
else: else:
use_amp = False use_amp = False
print('AMP disabled') logging.info('AMP disabled')
model_ema = None model_ema = None
if args.model_ema: if args.model_ema:
@ -222,11 +222,11 @@ def main():
if start_epoch > 0: if start_epoch > 0:
lr_scheduler.step(start_epoch) lr_scheduler.step(start_epoch)
if args.local_rank == 0: if args.local_rank == 0:
print('Scheduled epochs: ', num_epochs) logging.info('Scheduled epochs: {}'.format(num_epochs))
train_dir = os.path.join(args.data, 'train') train_dir = os.path.join(args.data, 'train')
if not os.path.exists(train_dir): if not os.path.exists(train_dir):
print('Error: training folder does not exist at: %s' % train_dir) logging.error('Training folder does not exist at: {}'.format(train_dir))
exit(1) exit(1)
dataset_train = Dataset(train_dir) dataset_train = Dataset(train_dir)
@ -252,7 +252,7 @@ def main():
eval_dir = os.path.join(args.data, 'validation') eval_dir = os.path.join(args.data, 'validation')
if not os.path.isdir(eval_dir): if not os.path.isdir(eval_dir):
print('Error: validation folder does not exist at: %s' % eval_dir) logging.error('Validation folder does not exist at: {}'.format(eval_dir))
exit(1) exit(1)
dataset_eval = Dataset(eval_dir) dataset_eval = Dataset(eval_dir)
@ -332,7 +332,7 @@ def main():
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
if best_metric is not None: if best_metric is not None:
print('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch)) logging.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
def train_epoch( def train_epoch(
@ -394,21 +394,22 @@ def train_epoch(
losses_m.update(reduced_loss.item(), input.size(0)) losses_m.update(reduced_loss.item(), input.size(0))
if args.local_rank == 0: if args.local_rank == 0:
print('Train: {} [{}/{} ({:.0f}%)] ' logging.info(
'Loss: {loss.val:.6f} ({loss.avg:.4f}) ' 'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
'Time: {batch_time.val:.3f}s, {rate:.3f}/s ' 'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
'({batch_time.avg:.3f}s, {rate_avg:.3f}/s) ' 'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
'LR: {lr:.4f} ' '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format( 'LR: {lr:.3e} '
epoch, 'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
batch_idx, len(loader), epoch,
100. * batch_idx / last_idx, batch_idx, len(loader),
loss=losses_m, 100. * batch_idx / last_idx,
batch_time=batch_time_m, loss=losses_m,
rate=input.size(0) * args.world_size / batch_time_m.val, batch_time=batch_time_m,
rate_avg=input.size(0) * args.world_size / batch_time_m.avg, rate=input.size(0) * args.world_size / batch_time_m.val,
lr=lr, rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
data_time=data_time_m)) lr=lr,
data_time=data_time_m))
if args.save_images and output_dir: if args.save_images and output_dir:
torchvision.utils.save_image( torchvision.utils.save_image(
@ -478,14 +479,15 @@ def validate(model, loader, loss_fn, args, log_suffix=''):
end = time.time() end = time.time()
if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0): if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
log_name = 'Test' + log_suffix log_name = 'Test' + log_suffix
print('{0}: [{1}/{2}]\t' logging.info(
'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) ' '{0}: [{1:>4d}/{2}] '
'Loss {loss.val:.4f} ({loss.avg:.4f}) ' 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
'Prec@1 {top1.val:.4f} ({top1.avg:.4f}) ' 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
'Prec@5 {top5.val:.4f} ({top5.avg:.4f})'.format( 'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
log_name, batch_idx, last_idx, 'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
batch_time=batch_time_m, loss=losses_m, log_name, batch_idx, last_idx,
top1=prec1_m, top5=prec5_m)) batch_time=batch_time_m, loss=losses_m,
top1=prec1_m, top5=prec5_m))
metrics = OrderedDict([('loss', losses_m.avg), ('prec1', prec1_m.avg), ('prec5', prec5_m.avg)]) metrics = OrderedDict([('loss', losses_m.avg), ('prec1', prec1_m.avg), ('prec5', prec5_m.avg)])

@ -7,6 +7,7 @@ import os
import csv import csv
import glob import glob
import time import time
import logging
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.parallel import torch.nn.parallel
@ -14,7 +15,7 @@ from collections import OrderedDict
from timm.models import create_model, apply_test_time_pool, load_checkpoint from timm.models import create_model, apply_test_time_pool, load_checkpoint
from timm.data import Dataset, create_loader, resolve_data_config from timm.data import Dataset, create_loader, resolve_data_config
from timm.utils import accuracy, AverageMeter, natural_key from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
@ -37,8 +38,8 @@ parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)') help='Image resize interpolation type (overrides model)')
parser.add_argument('--num-classes', type=int, default=1000, parser.add_argument('--num-classes', type=int, default=1000,
help='Number classes in dataset') help='Number classes in dataset')
parser.add_argument('--print-freq', '-p', default=10, type=int, parser.add_argument('--log-freq', default=10, type=int,
metavar='N', help='print frequency (default: 10)') metavar='N', help='batch logging frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)') help='path to latest checkpoint (default: none)')
parser.add_argument('--pretrained', dest='pretrained', action='store_true', parser.add_argument('--pretrained', dest='pretrained', action='store_true',
@ -68,7 +69,7 @@ def validate(args):
load_checkpoint(model, args.checkpoint, args.use_ema) load_checkpoint(model, args.checkpoint, args.use_ema)
param_count = sum([m.numel() for m in model.parameters()]) param_count = sum([m.numel() for m in model.parameters()])
print('Model %s created, param count: %d' % (args.model, param_count)) logging.info('Model %s created, param count: %d' % (args.model, param_count))
data_config = resolve_data_config(model, args) data_config = resolve_data_config(model, args)
model, test_time_pool = apply_test_time_pool(model, data_config, args) model, test_time_pool = apply_test_time_pool(model, data_config, args)
@ -118,28 +119,30 @@ def validate(args):
batch_time.update(time.time() - end) batch_time.update(time.time() - end)
end = time.time() end = time.time()
if i % args.print_freq == 0: if i % args.log_freq == 0:
print('Test: [{0}/{1}]\t' logging.info(
'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s) \t' 'Test: [{0:>4d}/{1}] '
'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
i, len(loader), batch_time=batch_time, 'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
rate_avg=input.size(0) / batch_time.avg, i, len(loader), batch_time=batch_time,
loss=losses, top1=top1, top5=top5)) rate_avg=input.size(0) / batch_time.avg,
loss=losses, top1=top1, top5=top5))
results = OrderedDict( results = OrderedDict(
top1=round(top1.avg, 3), top1_err=round(100 - top1.avg, 3), top1=round(top1.avg, 3), top1_err=round(100 - top1.avg, 3),
top5=round(top5.avg, 3), top5_err=round(100 - top5.avg, 3), top5=round(top5.avg, 3), top5_err=round(100 - top5.avg, 3),
param_count=round(param_count / 1e6, 2)) param_count=round(param_count / 1e6, 2))
print(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format( logging.info(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format(
results['top1'], results['top1_err'], results['top5'], results['top5_err'])) results['top1'], results['top1_err'], results['top5'], results['top5_err']))
return results return results
def main(): def main():
setup_default_logging()
args = parser.parse_args() args = parser.parse_args()
if args.model == 'all': if args.model == 'all':
# validate all models in a list of names with pretrained checkpoints # validate all models in a list of names with pretrained checkpoints

Loading…
Cancel
Save