diff --git a/timm/data/__init__.py b/timm/data/__init__.py
index 7cc7b0b0..9f62a7d5 100644
--- a/timm/data/__init__.py
+++ b/timm/data/__init__.py
@@ -1,6 +1,6 @@
 from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
     rand_augment_transform, auto_augment_transform
-from .config import resolve_data_config
+from .config import resolve_data_config, resolve_model_data_config
 from .constants import *
 from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
 from .dataset_factory import create_dataset
diff --git a/timm/data/config.py b/timm/data/config.py
index a65695d0..a6c2298c 100644
--- a/timm/data/config.py
+++ b/timm/data/config.py
@@ -6,16 +6,18 @@ _logger = logging.getLogger(__name__)
 
 
 def resolve_data_config(
-        args,
-        default_cfg=None,
+        args=None,
+        pretrained_cfg=None,
         model=None,
         use_test_size=False,
         verbose=False
 ):
-    new_config = {}
-    default_cfg = default_cfg or {}
-    if not default_cfg and model is not None and hasattr(model, 'default_cfg'):
-        default_cfg = model.default_cfg
+    assert model or args or pretrained_cfg, "At least one of model, args, or pretrained_cfg required for data config."
+    args = args or {}
+    pretrained_cfg = pretrained_cfg or {}
+    if not pretrained_cfg and model is not None and hasattr(model, 'pretrained_cfg'):
+        pretrained_cfg = model.pretrained_cfg
+    data_config = {}
 
     # Resolve input/image size
     in_chans = 3
@@ -32,65 +34,94 @@ def resolve_data_config(
             assert isinstance(args['img_size'], int)
             input_size = (in_chans, args['img_size'], args['img_size'])
     else:
-        if use_test_size and default_cfg.get('test_input_size', None) is not None:
-            input_size = default_cfg['test_input_size']
-        elif default_cfg.get('input_size', None) is not None:
-            input_size = default_cfg['input_size']
-    new_config['input_size'] = input_size
+        if use_test_size and pretrained_cfg.get('test_input_size', None) is not None:
+            input_size = pretrained_cfg['test_input_size']
+        elif pretrained_cfg.get('input_size', None) is not None:
+            input_size = pretrained_cfg['input_size']
+    data_config['input_size'] = input_size
 
     # resolve interpolation method
-    new_config['interpolation'] = 'bicubic'
+    data_config['interpolation'] = 'bicubic'
     if args.get('interpolation', None):
-        new_config['interpolation'] = args['interpolation']
-    elif default_cfg.get('interpolation', None):
-        new_config['interpolation'] = default_cfg['interpolation']
+        data_config['interpolation'] = args['interpolation']
+    elif pretrained_cfg.get('interpolation', None):
+        data_config['interpolation'] = pretrained_cfg['interpolation']
 
     # resolve dataset + model mean for normalization
-    new_config['mean'] = IMAGENET_DEFAULT_MEAN
+    data_config['mean'] = IMAGENET_DEFAULT_MEAN
     if args.get('mean', None) is not None:
         mean = tuple(args['mean'])
         if len(mean) == 1:
             mean = tuple(list(mean) * in_chans)
         else:
             assert len(mean) == in_chans
-        new_config['mean'] = mean
-    elif default_cfg.get('mean', None):
-        new_config['mean'] = default_cfg['mean']
+        data_config['mean'] = mean
+    elif pretrained_cfg.get('mean', None):
+        data_config['mean'] = pretrained_cfg['mean']
 
     # resolve dataset + model std deviation for normalization
-    new_config['std'] = IMAGENET_DEFAULT_STD
+    data_config['std'] = IMAGENET_DEFAULT_STD
     if args.get('std', None) is not None:
         std = tuple(args['std'])
         if len(std) == 1:
             std = tuple(list(std) * in_chans)
         else:
             assert len(std) == in_chans
-        new_config['std'] = std
-    elif default_cfg.get('std', None):
-        new_config['std'] = default_cfg['std']
+        data_config['std'] = std
+    elif pretrained_cfg.get('std', None):
+        data_config['std'] = pretrained_cfg['std']
 
     # resolve default inference crop
     crop_pct = DEFAULT_CROP_PCT
     if args.get('crop_pct', None):
         crop_pct = args['crop_pct']
     else:
-        if use_test_size and default_cfg.get('test_crop_pct', None):
-            crop_pct = default_cfg['test_crop_pct']
-        elif default_cfg.get('crop_pct', None):
-            crop_pct = default_cfg['crop_pct']
-    new_config['crop_pct'] = crop_pct
+        if use_test_size and pretrained_cfg.get('test_crop_pct', None):
+            crop_pct = pretrained_cfg['test_crop_pct']
+        elif pretrained_cfg.get('crop_pct', None):
+            crop_pct = pretrained_cfg['crop_pct']
+    data_config['crop_pct'] = crop_pct
 
     # resolve default crop percentage
     crop_mode = DEFAULT_CROP_MODE
     if args.get('crop_mode', None):
         crop_mode = args['crop_mode']
-    elif default_cfg.get('crop_mode', None):
-        crop_mode = default_cfg['crop_mode']
-    new_config['crop_mode'] = crop_mode
+    elif pretrained_cfg.get('crop_mode', None):
+        crop_mode = pretrained_cfg['crop_mode']
+    data_config['crop_mode'] = crop_mode
 
     if verbose:
         _logger.info('Data processing configuration for current model + dataset:')
-        for n, v in new_config.items():
+        for n, v in data_config.items():
             _logger.info('\t%s: %s' % (n, str(v)))
 
-    return new_config
+    return data_config
+
+
+def resolve_model_data_config(
+        model,
+        args=None,
+        pretrained_cfg=None,
+        use_test_size=False,
+        verbose=False,
+):
+    """ Resolve Model Data Config
+    This is equivalent to resolve_data_config() but with arguments re-ordered to put model first.
+
+    Args:
+        model (nn.Module): the model instance
+        args (dict): command line arguments / configuration in dict form (overrides pretrained_cfg)
+        pretrained_cfg (dict): pretrained model config (overrides pretrained_cfg attached to model)
+        use_test_size (bool): use the test time input resolution (if one exists) instead of default train resolution
+        verbose (bool): enable extra logging of resolved values
+
+    Returns:
+        dictionary of config
+    """
+    return resolve_data_config(
+        args=args,
+        pretrained_cfg=pretrained_cfg,
+        model=model,
+        use_test_size=use_test_size,
+        verbose=verbose,
+    )
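
Usage note: with this change, the data configuration of a pretrained model can be resolved directly from the model instance via the newly exported resolve_model_data_config() and then passed to create_transform(). A minimal sketch of the intended flow; the 'convnext_tiny' model name is only an illustrative choice:

    import timm
    from timm.data import resolve_model_data_config, create_transform

    # Any timm model with a pretrained_cfg works here; 'convnext_tiny' is just an example.
    model = timm.create_model('convnext_tiny', pretrained=True)

    # Same resolution logic as resolve_data_config(model=model), with the model as first positional arg.
    data_config = resolve_model_data_config(model, use_test_size=False, verbose=True)

    # Build an eval-time transform (resize/crop/normalize) from the resolved config
    # (input_size, interpolation, mean, std, crop_pct, crop_mode).
    transform = create_transform(**data_config, is_training=False)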