From 78fa0772cc76e629ed3f39026e3b0faceee5c995 Mon Sep 17 00:00:00 2001 From: Antoine Broyelle <=> Date: Tue, 9 Jun 2020 18:28:48 +0100 Subject: [PATCH] Leverage python hierachical logger with this update one can tune the kind of logs generated by timm but training and inference traces are unchanged --- inference.py | 6 +++-- timm/data/config.py | 7 ++++-- timm/models/efficientnet_builder.py | 19 ++++++++------ timm/models/helpers.py | 17 +++++++------ timm/models/layers/test_time_pool.py | 5 +++- timm/utils.py | 17 +++++++------ train.py | 37 ++++++++++++++-------------- validate.py | 10 +++++--- 8 files changed, 69 insertions(+), 49 deletions(-) diff --git a/inference.py b/inference.py index 3ee994a6..a49af5f2 100755 --- a/inference.py +++ b/inference.py @@ -17,6 +17,8 @@ from timm.data import Dataset, create_loader, resolve_data_config from timm.utils import AverageMeter, setup_default_logging torch.backends.cudnn.benchmark = True +logger = logging.getLogger(__name__) + parser = argparse.ArgumentParser(description='PyTorch ImageNet Inference') parser.add_argument('data', metavar='DIR', @@ -67,7 +69,7 @@ def main(): pretrained=args.pretrained, checkpoint_path=args.checkpoint) - logging.info('Model %s created, param count: %d' % + logger.info('Model %s created, param count: %d' % (args.model, sum([m.numel() for m in model.parameters()]))) config = resolve_data_config(vars(args), model=model) @@ -107,7 +109,7 @@ def main(): end = time.time() if batch_idx % args.log_freq == 0: - logging.info('Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format( + logger.info('Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format( batch_idx, len(loader), batch_time=batch_time)) topk_ids = np.concatenate(topk_ids, axis=0).squeeze() diff --git a/timm/data/config.py b/timm/data/config.py index dbae0da7..564ea9b4 100644 --- a/timm/data/config.py +++ b/timm/data/config.py @@ -2,6 +2,9 @@ import logging from .constants import * +logger = logging.getLogger(__name__) + + def resolve_data_config(args, default_cfg={}, model=None, verbose=True): new_config = {} default_cfg = default_cfg @@ -65,8 +68,8 @@ def resolve_data_config(args, default_cfg={}, model=None, verbose=True): new_config['crop_pct'] = default_cfg['crop_pct'] if verbose: - logging.info('Data processing configuration for current model + dataset:') + logger.info('Data processing configuration for current model + dataset:') for n, v in new_config.items(): - logging.info('\t%s: %s' % (n, str(v))) + logger.info('\t%s: %s' % (n, str(v))) return new_config diff --git a/timm/models/efficientnet_builder.py b/timm/models/efficientnet_builder.py index 842098cf..9952eb64 100644 --- a/timm/models/efficientnet_builder.py +++ b/timm/models/efficientnet_builder.py @@ -10,6 +10,9 @@ from .layers.activations import HardSwish, Swish from .efficientnet_blocks import * +logger = logging.getLogger(__name__) + + def _parse_ksize(ss): if ss.isdigit(): return int(ss) @@ -246,7 +249,7 @@ class EfficientNetBuilder: ba['drop_path_rate'] = drop_path_rate ba['se_kwargs'] = self.se_kwargs if self.verbose: - logging.info(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba))) + logger.info(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba))) if ba.get('num_experts', 0) > 0: block = CondConvResidual(**ba) else: @@ -255,17 +258,17 @@ class EfficientNetBuilder: ba['drop_path_rate'] = drop_path_rate ba['se_kwargs'] = self.se_kwargs if self.verbose: - logging.info(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba))) + logger.info(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba))) block = DepthwiseSeparableConv(**ba) elif bt == 'er': ba['drop_path_rate'] = drop_path_rate ba['se_kwargs'] = self.se_kwargs if self.verbose: - logging.info(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba))) + logger.info(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba))) block = EdgeResidual(**ba) elif bt == 'cn': if self.verbose: - logging.info(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba))) + logger.info(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba))) block = ConvBnAct(**ba) else: assert False, 'Uknkown block type (%s) while building model.' % bt @@ -283,7 +286,7 @@ class EfficientNetBuilder: List of block stacks (each stack wrapped in nn.Sequential) """ if self.verbose: - logging.info('Building model trunk with %d stages...' % len(model_block_args)) + logger.info('Building model trunk with %d stages...' % len(model_block_args)) self.in_chs = in_chs total_block_count = sum([len(x) for x in model_block_args]) total_block_idx = 0 @@ -295,7 +298,7 @@ class EfficientNetBuilder: for stage_idx, stage_block_args in enumerate(model_block_args): last_stack = stage_idx == (len(model_block_args) - 1) if self.verbose: - logging.info('Stack: {}'.format(stage_idx)) + logger.info('Stack: {}'.format(stage_idx)) assert isinstance(stage_block_args, list) blocks = [] @@ -304,7 +307,7 @@ class EfficientNetBuilder: last_block = block_idx == (len(stage_block_args) - 1) extract_features = '' # No features extracted if self.verbose: - logging.info(' Block: {}'.format(block_idx)) + logger.info(' Block: {}'.format(block_idx)) # Sort out stride, dilation, and feature extraction details assert block_args['stride'] in (1, 2) @@ -334,7 +337,7 @@ class EfficientNetBuilder: next_dilation = current_dilation * block_args['stride'] block_args['stride'] = 1 if self.verbose: - logging.info(' Converting stride to dilation to maintain output_stride=={}'.format( + logger.info(' Converting stride to dilation to maintain output_stride=={}'.format( self.output_stride)) else: current_stride = next_output_stride diff --git a/timm/models/helpers.py b/timm/models/helpers.py index 3baad3bf..ad594872 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -8,6 +8,9 @@ from collections import OrderedDict from timm.models.layers.conv2d_same import Conv2dSame +logger = logging.getLogger(__name__) + + def load_state_dict(checkpoint_path, use_ema=False): if checkpoint_path and os.path.isfile(checkpoint_path): checkpoint = torch.load(checkpoint_path, map_location='cpu') @@ -24,10 +27,10 @@ def load_state_dict(checkpoint_path, use_ema=False): state_dict = new_state_dict else: state_dict = checkpoint - logging.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path)) + logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path)) return state_dict else: - logging.error("No checkpoint found at '{}'".format(checkpoint_path)) + logger.error("No checkpoint found at '{}'".format(checkpoint_path)) raise FileNotFoundError() @@ -55,13 +58,13 @@ def resume_checkpoint(model, checkpoint_path): resume_epoch = checkpoint['epoch'] if 'version' in checkpoint and checkpoint['version'] > 1: resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save - logging.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) + logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) else: model.load_state_dict(checkpoint) - logging.info("Loaded checkpoint '{}'".format(checkpoint_path)) + logger.info("Loaded checkpoint '{}'".format(checkpoint_path)) return other_state, resume_epoch else: - logging.error("No checkpoint found at '{}'".format(checkpoint_path)) + logger.error("No checkpoint found at '{}'".format(checkpoint_path)) raise FileNotFoundError() @@ -69,14 +72,14 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non if cfg is None: cfg = getattr(model, 'default_cfg') if cfg is None or 'url' not in cfg or not cfg['url']: - logging.warning("Pretrained model URL is invalid, using random initialization.") + logger.warning("Pretrained model URL is invalid, using random initialization.") return state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu') if in_chans == 1: conv1_name = cfg['first_conv'] - logging.info('Converting first conv (%s) from 3 to 1 channel' % conv1_name) + logger.info('Converting first conv (%s) from 3 to 1 channel' % conv1_name) conv1_weight = state_dict[conv1_name + '.weight'] state_dict[conv1_name + '.weight'] = conv1_weight.sum(dim=1, keepdim=True) elif in_chans != 3: diff --git a/timm/models/layers/test_time_pool.py b/timm/models/layers/test_time_pool.py index dcfc66ca..d2def43f 100644 --- a/timm/models/layers/test_time_pool.py +++ b/timm/models/layers/test_time_pool.py @@ -9,6 +9,9 @@ import torch.nn.functional as F from .adaptive_avgmax_pool import adaptive_avgmax_pool2d +logger = logging.getLogger(__name__) + + class TestTimePoolHead(nn.Module): def __init__(self, base, original_pool=7): super(TestTimePoolHead, self).__init__() @@ -39,7 +42,7 @@ def apply_test_time_pool(model, config, args): if not args.no_test_pool and \ config['input_size'][-1] > model.default_cfg['input_size'][-1] and \ config['input_size'][-2] > model.default_cfg['input_size'][-2]: - logging.info('Target input size %s > pretrained default %s, using test time pooling' % + logger.info('Target input size %s > pretrained default %s, using test time pooling' % (str(config['input_size'][-2:]), str(model.default_cfg['input_size'][-2:]))) model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size']) test_time_pool = True diff --git a/timm/utils.py b/timm/utils.py index 2cae024d..c6dc95bb 100644 --- a/timm/utils.py +++ b/timm/utils.py @@ -21,6 +21,9 @@ except ImportError: from torch import distributed as dist +logger = logging.getLogger(__name__) + + def unwrap_model(model): if isinstance(model, ModelEma): return unwrap_model(model.ema) @@ -84,7 +87,7 @@ class CheckpointSaver: checkpoints_str = "Current checkpoints:\n" for c in self.checkpoint_files: checkpoints_str += ' {}\n'.format(c) - logging.info(checkpoints_str) + logger.info(checkpoints_str) if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)): self.best_epoch = epoch @@ -121,10 +124,10 @@ class CheckpointSaver: to_delete = self.checkpoint_files[delete_index:] for d in to_delete: try: - logging.debug("Cleaning checkpoint: {}".format(d)) + logger.debug("Cleaning checkpoint: {}".format(d)) os.remove(d[0]) except Exception as e: - logging.error("Exception '{}' while deleting checkpoint".format(e)) + logger.error("Exception '{}' while deleting checkpoint".format(e)) self.checkpoint_files = self.checkpoint_files[:delete_index] def save_recovery(self, model, optimizer, args, epoch, model_ema=None, use_amp=False, batch_idx=0): @@ -134,10 +137,10 @@ class CheckpointSaver: self._save(save_path, model, optimizer, args, epoch, model_ema, use_amp=use_amp) if os.path.exists(self.last_recovery_file): try: - logging.debug("Cleaning recovery: {}".format(self.last_recovery_file)) + logger.debug("Cleaning recovery: {}".format(self.last_recovery_file)) os.remove(self.last_recovery_file) except Exception as e: - logging.error("Exception '{}' while removing {}".format(e, self.last_recovery_file)) + logger.error("Exception '{}' while removing {}".format(e, self.last_recovery_file)) self.last_recovery_file = self.curr_recovery_file self.curr_recovery_file = save_path @@ -279,9 +282,9 @@ class ModelEma: name = k new_state_dict[name] = v self.ema.load_state_dict(new_state_dict) - logging.info("Loaded state_dict_ema") + logger.info("Loaded state_dict_ema") else: - logging.warning("Failed to find state_dict_ema, starting from loaded model weights") + logger.warning("Failed to find state_dict_ema, starting from loaded model weights") def update(self, model): # correct a mismatch in state dict keys diff --git a/train.py b/train.py index 899c6984..c51eb2d4 100755 --- a/train.py +++ b/train.py @@ -40,6 +40,7 @@ import torch.nn as nn import torchvision.utils torch.backends.cudnn.benchmark = True +logger = logging.getLogger(__name__) # The first arg parser parses out only the --config argument, this argument is used to @@ -228,7 +229,7 @@ def main(): if 'WORLD_SIZE' in os.environ: args.distributed = int(os.environ['WORLD_SIZE']) > 1 if args.distributed and args.num_gpu > 1: - logging.warning('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.') + logger.warning('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.') args.num_gpu = 1 args.device = 'cuda:0' @@ -244,10 +245,10 @@ def main(): assert args.rank >= 0 if args.distributed: - logging.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' + logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' % (args.rank, args.world_size)) else: - logging.info('Training with a single process on %d GPUs.' % args.num_gpu) + logger.info('Training with a single process on %d GPUs.' % args.num_gpu) torch.manual_seed(args.seed + args.rank) @@ -266,7 +267,7 @@ def main(): checkpoint_path=args.initial_checkpoint) if args.local_rank == 0: - logging.info('Model %s created, param count: %d' % + logger.info('Model %s created, param count: %d' % (args.model, sum([m.numel() for m in model.parameters()]))) data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) @@ -282,7 +283,7 @@ def main(): if args.num_gpu > 1: if args.amp: - logging.warning( + logger.warning( 'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.') args.amp = False model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() @@ -296,7 +297,7 @@ def main(): model, optimizer = amp.initialize(model, optimizer, opt_level='O1') use_amp = True if args.local_rank == 0: - logging.info('NVIDIA APEX {}. AMP {}.'.format( + logger.info('NVIDIA APEX {}. AMP {}.'.format( 'installed' if has_apex else 'not installed', 'on' if use_amp else 'off')) # optionally resume from a checkpoint @@ -307,11 +308,11 @@ def main(): if resume_state and not args.no_resume_opt: if 'optimizer' in resume_state: if args.local_rank == 0: - logging.info('Restoring Optimizer state from checkpoint') + logger.info('Restoring Optimizer state from checkpoint') optimizer.load_state_dict(resume_state['optimizer']) if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__: if args.local_rank == 0: - logging.info('Restoring NVIDIA AMP state from checkpoint') + logger.info('Restoring NVIDIA AMP state from checkpoint') amp.load_state_dict(resume_state['amp']) del resume_state @@ -333,16 +334,16 @@ def main(): else: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) if args.local_rank == 0: - logging.info( + logger.info( 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') except Exception as e: - logging.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1') + logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1') if has_apex: model = DDP(model, delay_allreduce=True) else: if args.local_rank == 0: - logging.info("Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.") + logger.info("Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.") model = DDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1 # NOTE: EMA model does not need to be wrapped by DDP @@ -357,11 +358,11 @@ def main(): lr_scheduler.step(start_epoch) if args.local_rank == 0: - logging.info('Scheduled epochs: {}'.format(num_epochs)) + logger.info('Scheduled epochs: {}'.format(num_epochs)) train_dir = os.path.join(args.data, 'train') if not os.path.exists(train_dir): - logging.error('Training folder does not exist at: {}'.format(train_dir)) + logger.error('Training folder does not exist at: {}'.format(train_dir)) exit(1) dataset_train = Dataset(train_dir) @@ -400,7 +401,7 @@ def main(): if not os.path.isdir(eval_dir): eval_dir = os.path.join(args.data, 'validation') if not os.path.isdir(eval_dir): - logging.error('Validation folder does not exist at: {}'.format(eval_dir)) + logger.error('Validation folder does not exist at: {}'.format(eval_dir)) exit(1) dataset_eval = Dataset(eval_dir) @@ -464,7 +465,7 @@ def main(): if args.distributed and args.dist_bn in ('broadcast', 'reduce'): if args.local_rank == 0: - logging.info("Distributing BatchNorm running means and vars") + logger.info("Distributing BatchNorm running means and vars") distribute_bn(model, args.world_size, args.dist_bn == 'reduce') eval_metrics = validate(model, loader_eval, validate_loss_fn, args) @@ -495,7 +496,7 @@ def main(): except KeyboardInterrupt: pass if best_metric is not None: - logging.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch)) + logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch)) def train_epoch( @@ -555,7 +556,7 @@ def train_epoch( losses_m.update(reduced_loss.item(), input.size(0)) if args.local_rank == 0: - logging.info( + logger.info( 'Train: {} [{:>4d}/{} ({:>3.0f}%)] ' 'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) ' 'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s ' @@ -643,7 +644,7 @@ def validate(model, loader, loss_fn, args, log_suffix=''): end = time.time() if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0): log_name = 'Test' + log_suffix - logging.info( + logger.info( '{0}: [{1:>4d}/{2}] ' 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' diff --git a/validate.py b/validate.py index f8ac7c55..0962139a 100755 --- a/validate.py +++ b/validate.py @@ -29,6 +29,8 @@ from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging torch.backends.cudnn.benchmark = True +logger = logging.getLogger(__name__) + parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation') parser.add_argument('data', metavar='DIR', @@ -95,7 +97,7 @@ def validate(args): load_checkpoint(model, args.checkpoint, args.use_ema) param_count = sum([m.numel() for m in model.parameters()]) - logging.info('Model %s created, param count: %d' % (args.model, param_count)) + logger.info('Model %s created, param count: %d' % (args.model, param_count)) data_config = resolve_data_config(vars(args), model=model) model, test_time_pool = apply_test_time_pool(model, data_config, args) @@ -165,7 +167,7 @@ def validate(args): end = time.time() if i % args.log_freq == 0: - logging.info( + logger.info( 'Test: [{0:>4d}/{1}] ' 'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) ' 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' @@ -183,7 +185,7 @@ def validate(args): cropt_pct=crop_pct, interpolation=data_config['interpolation']) - logging.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format( + logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format( results['top1'], results['top1_err'], results['top5'], results['top5_err'])) return results @@ -213,7 +215,7 @@ def main(): if len(model_cfgs): results_file = args.results_file or './results-all.csv' - logging.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names))) + logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names))) results = [] try: start_batch_size = args.batch_size