Leverage python hierachical logger

with this update one can tune the kind of logs generated by timm but
training and inference traces are unchanged
pull/163/head
Antoine Broyelle 5 years ago
parent 5966654052
commit 78fa0772cc

@ -17,6 +17,8 @@ from timm.data import Dataset, create_loader, resolve_data_config
from timm.utils import AverageMeter, setup_default_logging from timm.utils import AverageMeter, setup_default_logging
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
logger = logging.getLogger(__name__)
parser = argparse.ArgumentParser(description='PyTorch ImageNet Inference') parser = argparse.ArgumentParser(description='PyTorch ImageNet Inference')
parser.add_argument('data', metavar='DIR', parser.add_argument('data', metavar='DIR',
@ -67,7 +69,7 @@ def main():
pretrained=args.pretrained, pretrained=args.pretrained,
checkpoint_path=args.checkpoint) checkpoint_path=args.checkpoint)
logging.info('Model %s created, param count: %d' % logger.info('Model %s created, param count: %d' %
(args.model, sum([m.numel() for m in model.parameters()]))) (args.model, sum([m.numel() for m in model.parameters()])))
config = resolve_data_config(vars(args), model=model) config = resolve_data_config(vars(args), model=model)
@ -107,7 +109,7 @@ def main():
end = time.time() end = time.time()
if batch_idx % args.log_freq == 0: if batch_idx % args.log_freq == 0:
logging.info('Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format( logger.info('Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
batch_idx, len(loader), batch_time=batch_time)) batch_idx, len(loader), batch_time=batch_time))
topk_ids = np.concatenate(topk_ids, axis=0).squeeze() topk_ids = np.concatenate(topk_ids, axis=0).squeeze()

@ -2,6 +2,9 @@ import logging
from .constants import * from .constants import *
logger = logging.getLogger(__name__)
def resolve_data_config(args, default_cfg={}, model=None, verbose=True): def resolve_data_config(args, default_cfg={}, model=None, verbose=True):
new_config = {} new_config = {}
default_cfg = default_cfg default_cfg = default_cfg
@ -65,8 +68,8 @@ def resolve_data_config(args, default_cfg={}, model=None, verbose=True):
new_config['crop_pct'] = default_cfg['crop_pct'] new_config['crop_pct'] = default_cfg['crop_pct']
if verbose: if verbose:
logging.info('Data processing configuration for current model + dataset:') logger.info('Data processing configuration for current model + dataset:')
for n, v in new_config.items(): for n, v in new_config.items():
logging.info('\t%s: %s' % (n, str(v))) logger.info('\t%s: %s' % (n, str(v)))
return new_config return new_config

@ -10,6 +10,9 @@ from .layers.activations import HardSwish, Swish
from .efficientnet_blocks import * from .efficientnet_blocks import *
logger = logging.getLogger(__name__)
def _parse_ksize(ss): def _parse_ksize(ss):
if ss.isdigit(): if ss.isdigit():
return int(ss) return int(ss)
@ -246,7 +249,7 @@ class EfficientNetBuilder:
ba['drop_path_rate'] = drop_path_rate ba['drop_path_rate'] = drop_path_rate
ba['se_kwargs'] = self.se_kwargs ba['se_kwargs'] = self.se_kwargs
if self.verbose: if self.verbose:
logging.info(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba))) logger.info(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)))
if ba.get('num_experts', 0) > 0: if ba.get('num_experts', 0) > 0:
block = CondConvResidual(**ba) block = CondConvResidual(**ba)
else: else:
@ -255,17 +258,17 @@ class EfficientNetBuilder:
ba['drop_path_rate'] = drop_path_rate ba['drop_path_rate'] = drop_path_rate
ba['se_kwargs'] = self.se_kwargs ba['se_kwargs'] = self.se_kwargs
if self.verbose: if self.verbose:
logging.info(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba))) logger.info(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)))
block = DepthwiseSeparableConv(**ba) block = DepthwiseSeparableConv(**ba)
elif bt == 'er': elif bt == 'er':
ba['drop_path_rate'] = drop_path_rate ba['drop_path_rate'] = drop_path_rate
ba['se_kwargs'] = self.se_kwargs ba['se_kwargs'] = self.se_kwargs
if self.verbose: if self.verbose:
logging.info(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba))) logger.info(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba)))
block = EdgeResidual(**ba) block = EdgeResidual(**ba)
elif bt == 'cn': elif bt == 'cn':
if self.verbose: if self.verbose:
logging.info(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba))) logger.info(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba)))
block = ConvBnAct(**ba) block = ConvBnAct(**ba)
else: else:
assert False, 'Uknkown block type (%s) while building model.' % bt assert False, 'Uknkown block type (%s) while building model.' % bt
@ -283,7 +286,7 @@ class EfficientNetBuilder:
List of block stacks (each stack wrapped in nn.Sequential) List of block stacks (each stack wrapped in nn.Sequential)
""" """
if self.verbose: if self.verbose:
logging.info('Building model trunk with %d stages...' % len(model_block_args)) logger.info('Building model trunk with %d stages...' % len(model_block_args))
self.in_chs = in_chs self.in_chs = in_chs
total_block_count = sum([len(x) for x in model_block_args]) total_block_count = sum([len(x) for x in model_block_args])
total_block_idx = 0 total_block_idx = 0
@ -295,7 +298,7 @@ class EfficientNetBuilder:
for stage_idx, stage_block_args in enumerate(model_block_args): for stage_idx, stage_block_args in enumerate(model_block_args):
last_stack = stage_idx == (len(model_block_args) - 1) last_stack = stage_idx == (len(model_block_args) - 1)
if self.verbose: if self.verbose:
logging.info('Stack: {}'.format(stage_idx)) logger.info('Stack: {}'.format(stage_idx))
assert isinstance(stage_block_args, list) assert isinstance(stage_block_args, list)
blocks = [] blocks = []
@ -304,7 +307,7 @@ class EfficientNetBuilder:
last_block = block_idx == (len(stage_block_args) - 1) last_block = block_idx == (len(stage_block_args) - 1)
extract_features = '' # No features extracted extract_features = '' # No features extracted
if self.verbose: if self.verbose:
logging.info(' Block: {}'.format(block_idx)) logger.info(' Block: {}'.format(block_idx))
# Sort out stride, dilation, and feature extraction details # Sort out stride, dilation, and feature extraction details
assert block_args['stride'] in (1, 2) assert block_args['stride'] in (1, 2)
@ -334,7 +337,7 @@ class EfficientNetBuilder:
next_dilation = current_dilation * block_args['stride'] next_dilation = current_dilation * block_args['stride']
block_args['stride'] = 1 block_args['stride'] = 1
if self.verbose: if self.verbose:
logging.info(' Converting stride to dilation to maintain output_stride=={}'.format( logger.info(' Converting stride to dilation to maintain output_stride=={}'.format(
self.output_stride)) self.output_stride))
else: else:
current_stride = next_output_stride current_stride = next_output_stride

@ -8,6 +8,9 @@ from collections import OrderedDict
from timm.models.layers.conv2d_same import Conv2dSame from timm.models.layers.conv2d_same import Conv2dSame
logger = logging.getLogger(__name__)
def load_state_dict(checkpoint_path, use_ema=False): def load_state_dict(checkpoint_path, use_ema=False):
if checkpoint_path and os.path.isfile(checkpoint_path): if checkpoint_path and os.path.isfile(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location='cpu') checkpoint = torch.load(checkpoint_path, map_location='cpu')
@ -24,10 +27,10 @@ def load_state_dict(checkpoint_path, use_ema=False):
state_dict = new_state_dict state_dict = new_state_dict
else: else:
state_dict = checkpoint state_dict = checkpoint
logging.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path)) logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
return state_dict return state_dict
else: else:
logging.error("No checkpoint found at '{}'".format(checkpoint_path)) logger.error("No checkpoint found at '{}'".format(checkpoint_path))
raise FileNotFoundError() raise FileNotFoundError()
@ -55,13 +58,13 @@ def resume_checkpoint(model, checkpoint_path):
resume_epoch = checkpoint['epoch'] resume_epoch = checkpoint['epoch']
if 'version' in checkpoint and checkpoint['version'] > 1: if 'version' in checkpoint and checkpoint['version'] > 1:
resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save
logging.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
else: else:
model.load_state_dict(checkpoint) model.load_state_dict(checkpoint)
logging.info("Loaded checkpoint '{}'".format(checkpoint_path)) logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
return other_state, resume_epoch return other_state, resume_epoch
else: else:
logging.error("No checkpoint found at '{}'".format(checkpoint_path)) logger.error("No checkpoint found at '{}'".format(checkpoint_path))
raise FileNotFoundError() raise FileNotFoundError()
@ -69,14 +72,14 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
if cfg is None: if cfg is None:
cfg = getattr(model, 'default_cfg') cfg = getattr(model, 'default_cfg')
if cfg is None or 'url' not in cfg or not cfg['url']: if cfg is None or 'url' not in cfg or not cfg['url']:
logging.warning("Pretrained model URL is invalid, using random initialization.") logger.warning("Pretrained model URL is invalid, using random initialization.")
return return
state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu') state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu')
if in_chans == 1: if in_chans == 1:
conv1_name = cfg['first_conv'] conv1_name = cfg['first_conv']
logging.info('Converting first conv (%s) from 3 to 1 channel' % conv1_name) logger.info('Converting first conv (%s) from 3 to 1 channel' % conv1_name)
conv1_weight = state_dict[conv1_name + '.weight'] conv1_weight = state_dict[conv1_name + '.weight']
state_dict[conv1_name + '.weight'] = conv1_weight.sum(dim=1, keepdim=True) state_dict[conv1_name + '.weight'] = conv1_weight.sum(dim=1, keepdim=True)
elif in_chans != 3: elif in_chans != 3:

@ -9,6 +9,9 @@ import torch.nn.functional as F
from .adaptive_avgmax_pool import adaptive_avgmax_pool2d from .adaptive_avgmax_pool import adaptive_avgmax_pool2d
logger = logging.getLogger(__name__)
class TestTimePoolHead(nn.Module): class TestTimePoolHead(nn.Module):
def __init__(self, base, original_pool=7): def __init__(self, base, original_pool=7):
super(TestTimePoolHead, self).__init__() super(TestTimePoolHead, self).__init__()
@ -39,7 +42,7 @@ def apply_test_time_pool(model, config, args):
if not args.no_test_pool and \ if not args.no_test_pool and \
config['input_size'][-1] > model.default_cfg['input_size'][-1] and \ config['input_size'][-1] > model.default_cfg['input_size'][-1] and \
config['input_size'][-2] > model.default_cfg['input_size'][-2]: config['input_size'][-2] > model.default_cfg['input_size'][-2]:
logging.info('Target input size %s > pretrained default %s, using test time pooling' % logger.info('Target input size %s > pretrained default %s, using test time pooling' %
(str(config['input_size'][-2:]), str(model.default_cfg['input_size'][-2:]))) (str(config['input_size'][-2:]), str(model.default_cfg['input_size'][-2:])))
model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size']) model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size'])
test_time_pool = True test_time_pool = True

@ -21,6 +21,9 @@ except ImportError:
from torch import distributed as dist from torch import distributed as dist
logger = logging.getLogger(__name__)
def unwrap_model(model): def unwrap_model(model):
if isinstance(model, ModelEma): if isinstance(model, ModelEma):
return unwrap_model(model.ema) return unwrap_model(model.ema)
@ -84,7 +87,7 @@ class CheckpointSaver:
checkpoints_str = "Current checkpoints:\n" checkpoints_str = "Current checkpoints:\n"
for c in self.checkpoint_files: for c in self.checkpoint_files:
checkpoints_str += ' {}\n'.format(c) checkpoints_str += ' {}\n'.format(c)
logging.info(checkpoints_str) logger.info(checkpoints_str)
if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)): if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)):
self.best_epoch = epoch self.best_epoch = epoch
@ -121,10 +124,10 @@ class CheckpointSaver:
to_delete = self.checkpoint_files[delete_index:] to_delete = self.checkpoint_files[delete_index:]
for d in to_delete: for d in to_delete:
try: try:
logging.debug("Cleaning checkpoint: {}".format(d)) logger.debug("Cleaning checkpoint: {}".format(d))
os.remove(d[0]) os.remove(d[0])
except Exception as e: except Exception as e:
logging.error("Exception '{}' while deleting checkpoint".format(e)) logger.error("Exception '{}' while deleting checkpoint".format(e))
self.checkpoint_files = self.checkpoint_files[:delete_index] self.checkpoint_files = self.checkpoint_files[:delete_index]
def save_recovery(self, model, optimizer, args, epoch, model_ema=None, use_amp=False, batch_idx=0): def save_recovery(self, model, optimizer, args, epoch, model_ema=None, use_amp=False, batch_idx=0):
@ -134,10 +137,10 @@ class CheckpointSaver:
self._save(save_path, model, optimizer, args, epoch, model_ema, use_amp=use_amp) self._save(save_path, model, optimizer, args, epoch, model_ema, use_amp=use_amp)
if os.path.exists(self.last_recovery_file): if os.path.exists(self.last_recovery_file):
try: try:
logging.debug("Cleaning recovery: {}".format(self.last_recovery_file)) logger.debug("Cleaning recovery: {}".format(self.last_recovery_file))
os.remove(self.last_recovery_file) os.remove(self.last_recovery_file)
except Exception as e: except Exception as e:
logging.error("Exception '{}' while removing {}".format(e, self.last_recovery_file)) logger.error("Exception '{}' while removing {}".format(e, self.last_recovery_file))
self.last_recovery_file = self.curr_recovery_file self.last_recovery_file = self.curr_recovery_file
self.curr_recovery_file = save_path self.curr_recovery_file = save_path
@ -279,9 +282,9 @@ class ModelEma:
name = k name = k
new_state_dict[name] = v new_state_dict[name] = v
self.ema.load_state_dict(new_state_dict) self.ema.load_state_dict(new_state_dict)
logging.info("Loaded state_dict_ema") logger.info("Loaded state_dict_ema")
else: else:
logging.warning("Failed to find state_dict_ema, starting from loaded model weights") logger.warning("Failed to find state_dict_ema, starting from loaded model weights")
def update(self, model): def update(self, model):
# correct a mismatch in state dict keys # correct a mismatch in state dict keys

@ -40,6 +40,7 @@ import torch.nn as nn
import torchvision.utils import torchvision.utils
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
logger = logging.getLogger(__name__)
# The first arg parser parses out only the --config argument, this argument is used to # The first arg parser parses out only the --config argument, this argument is used to
@ -228,7 +229,7 @@ def main():
if 'WORLD_SIZE' in os.environ: if 'WORLD_SIZE' in os.environ:
args.distributed = int(os.environ['WORLD_SIZE']) > 1 args.distributed = int(os.environ['WORLD_SIZE']) > 1
if args.distributed and args.num_gpu > 1: if args.distributed and args.num_gpu > 1:
logging.warning('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.') logger.warning('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
args.num_gpu = 1 args.num_gpu = 1
args.device = 'cuda:0' args.device = 'cuda:0'
@ -244,10 +245,10 @@ def main():
assert args.rank >= 0 assert args.rank >= 0
if args.distributed: if args.distributed:
logging.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
% (args.rank, args.world_size)) % (args.rank, args.world_size))
else: else:
logging.info('Training with a single process on %d GPUs.' % args.num_gpu) logger.info('Training with a single process on %d GPUs.' % args.num_gpu)
torch.manual_seed(args.seed + args.rank) torch.manual_seed(args.seed + args.rank)
@ -266,7 +267,7 @@ def main():
checkpoint_path=args.initial_checkpoint) checkpoint_path=args.initial_checkpoint)
if args.local_rank == 0: if args.local_rank == 0:
logging.info('Model %s created, param count: %d' % logger.info('Model %s created, param count: %d' %
(args.model, sum([m.numel() for m in model.parameters()]))) (args.model, sum([m.numel() for m in model.parameters()])))
data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
@ -282,7 +283,7 @@ def main():
if args.num_gpu > 1: if args.num_gpu > 1:
if args.amp: if args.amp:
logging.warning( logger.warning(
'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.') 'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.')
args.amp = False args.amp = False
model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
@ -296,7 +297,7 @@ def main():
model, optimizer = amp.initialize(model, optimizer, opt_level='O1') model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
use_amp = True use_amp = True
if args.local_rank == 0: if args.local_rank == 0:
logging.info('NVIDIA APEX {}. AMP {}.'.format( logger.info('NVIDIA APEX {}. AMP {}.'.format(
'installed' if has_apex else 'not installed', 'on' if use_amp else 'off')) 'installed' if has_apex else 'not installed', 'on' if use_amp else 'off'))
# optionally resume from a checkpoint # optionally resume from a checkpoint
@ -307,11 +308,11 @@ def main():
if resume_state and not args.no_resume_opt: if resume_state and not args.no_resume_opt:
if 'optimizer' in resume_state: if 'optimizer' in resume_state:
if args.local_rank == 0: if args.local_rank == 0:
logging.info('Restoring Optimizer state from checkpoint') logger.info('Restoring Optimizer state from checkpoint')
optimizer.load_state_dict(resume_state['optimizer']) optimizer.load_state_dict(resume_state['optimizer'])
if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__: if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
if args.local_rank == 0: if args.local_rank == 0:
logging.info('Restoring NVIDIA AMP state from checkpoint') logger.info('Restoring NVIDIA AMP state from checkpoint')
amp.load_state_dict(resume_state['amp']) amp.load_state_dict(resume_state['amp'])
del resume_state del resume_state
@ -333,16 +334,16 @@ def main():
else: else:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
if args.local_rank == 0: if args.local_rank == 0:
logging.info( logger.info(
'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
except Exception as e: except Exception as e:
logging.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1') logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
if has_apex: if has_apex:
model = DDP(model, delay_allreduce=True) model = DDP(model, delay_allreduce=True)
else: else:
if args.local_rank == 0: if args.local_rank == 0:
logging.info("Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.") logger.info("Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
model = DDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1 model = DDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1
# NOTE: EMA model does not need to be wrapped by DDP # NOTE: EMA model does not need to be wrapped by DDP
@ -357,11 +358,11 @@ def main():
lr_scheduler.step(start_epoch) lr_scheduler.step(start_epoch)
if args.local_rank == 0: if args.local_rank == 0:
logging.info('Scheduled epochs: {}'.format(num_epochs)) logger.info('Scheduled epochs: {}'.format(num_epochs))
train_dir = os.path.join(args.data, 'train') train_dir = os.path.join(args.data, 'train')
if not os.path.exists(train_dir): if not os.path.exists(train_dir):
logging.error('Training folder does not exist at: {}'.format(train_dir)) logger.error('Training folder does not exist at: {}'.format(train_dir))
exit(1) exit(1)
dataset_train = Dataset(train_dir) dataset_train = Dataset(train_dir)
@ -400,7 +401,7 @@ def main():
if not os.path.isdir(eval_dir): if not os.path.isdir(eval_dir):
eval_dir = os.path.join(args.data, 'validation') eval_dir = os.path.join(args.data, 'validation')
if not os.path.isdir(eval_dir): if not os.path.isdir(eval_dir):
logging.error('Validation folder does not exist at: {}'.format(eval_dir)) logger.error('Validation folder does not exist at: {}'.format(eval_dir))
exit(1) exit(1)
dataset_eval = Dataset(eval_dir) dataset_eval = Dataset(eval_dir)
@ -464,7 +465,7 @@ def main():
if args.distributed and args.dist_bn in ('broadcast', 'reduce'): if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
if args.local_rank == 0: if args.local_rank == 0:
logging.info("Distributing BatchNorm running means and vars") logger.info("Distributing BatchNorm running means and vars")
distribute_bn(model, args.world_size, args.dist_bn == 'reduce') distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
eval_metrics = validate(model, loader_eval, validate_loss_fn, args) eval_metrics = validate(model, loader_eval, validate_loss_fn, args)
@ -495,7 +496,7 @@ def main():
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
if best_metric is not None: if best_metric is not None:
logging.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch)) logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
def train_epoch( def train_epoch(
@ -555,7 +556,7 @@ def train_epoch(
losses_m.update(reduced_loss.item(), input.size(0)) losses_m.update(reduced_loss.item(), input.size(0))
if args.local_rank == 0: if args.local_rank == 0:
logging.info( logger.info(
'Train: {} [{:>4d}/{} ({:>3.0f}%)] ' 'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) ' 'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s ' 'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
@ -643,7 +644,7 @@ def validate(model, loader, loss_fn, args, log_suffix=''):
end = time.time() end = time.time()
if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0): if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
log_name = 'Test' + log_suffix log_name = 'Test' + log_suffix
logging.info( logger.info(
'{0}: [{1:>4d}/{2}] ' '{0}: [{1:>4d}/{2}] '
'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '

@ -29,6 +29,8 @@ from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
logger = logging.getLogger(__name__)
parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation') parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
parser.add_argument('data', metavar='DIR', parser.add_argument('data', metavar='DIR',
@ -95,7 +97,7 @@ def validate(args):
load_checkpoint(model, args.checkpoint, args.use_ema) load_checkpoint(model, args.checkpoint, args.use_ema)
param_count = sum([m.numel() for m in model.parameters()]) param_count = sum([m.numel() for m in model.parameters()])
logging.info('Model %s created, param count: %d' % (args.model, param_count)) logger.info('Model %s created, param count: %d' % (args.model, param_count))
data_config = resolve_data_config(vars(args), model=model) data_config = resolve_data_config(vars(args), model=model)
model, test_time_pool = apply_test_time_pool(model, data_config, args) model, test_time_pool = apply_test_time_pool(model, data_config, args)
@ -165,7 +167,7 @@ def validate(args):
end = time.time() end = time.time()
if i % args.log_freq == 0: if i % args.log_freq == 0:
logging.info( logger.info(
'Test: [{0:>4d}/{1}] ' 'Test: [{0:>4d}/{1}] '
'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) ' 'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
@ -183,7 +185,7 @@ def validate(args):
cropt_pct=crop_pct, cropt_pct=crop_pct,
interpolation=data_config['interpolation']) interpolation=data_config['interpolation'])
logging.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format( logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
results['top1'], results['top1_err'], results['top5'], results['top5_err'])) results['top1'], results['top1_err'], results['top5'], results['top5_err']))
return results return results
@ -213,7 +215,7 @@ def main():
if len(model_cfgs): if len(model_cfgs):
results_file = args.results_file or './results-all.csv' results_file = args.results_file or './results-all.csv'
logging.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names))) logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
results = [] results = []
try: try:
start_batch_size = args.batch_size start_batch_size = args.batch_size

Loading…
Cancel
Save